Aaryaman Vasishta committed
Commit 38dbec8 · 1 Parent(s): 73dc205

Add spar3d demo files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +2 -0
- .gitignore +167 -0
- .pre-commit-config.yaml +24 -0
- LICENSE.md +51 -0
- README.md +3 -3
- __init__.py +358 -0
- demo_files/comp.gif +3 -0
- demo_files/examples/bird.png +3 -0
- demo_files/examples/castle.png +3 -0
- demo_files/examples/chest.png +3 -0
- demo_files/examples/doll.png +3 -0
- demo_files/examples/excavator.png +3 -0
- demo_files/examples/fish.png +3 -0
- demo_files/examples/horse-statue.png +3 -0
- demo_files/examples/penguin.png +3 -0
- demo_files/examples/pot.png +3 -0
- demo_files/examples/raccoon_wizard.png +3 -0
- demo_files/examples/stylized-rocks.png +3 -0
- demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
- demo_files/hdri/metro_noord_1k.hdr +0 -0
- demo_files/hdri/neon_photostudio_1k.hdr +0 -0
- demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
- demo_files/hdri/rainforest_trail_1k.hdr +0 -0
- demo_files/hdri/studio_small_08_1k.hdr +0 -0
- demo_files/hdri/urban_alley_01_1k.hdr +0 -0
- demo_files/turntable.gif +3 -0
- demo_files/workflows/spar3d_example.json +263 -0
- gradio_app.py +792 -0
- load/tets/160_tets.npz +3 -0
- requirements.txt +17 -0
- ruff.toml +3 -0
- run.py +180 -0
- spar3d/models/camera.py +32 -0
- spar3d/models/diffusion/gaussian_diffusion.py +524 -0
- spar3d/models/diffusion/sampler.py +134 -0
- spar3d/models/global_estimator/reni_estimator.py +112 -0
- spar3d/models/illumination/reni/components/film_siren.py +148 -0
- spar3d/models/illumination/reni/components/siren.py +118 -0
- spar3d/models/illumination/reni/components/transformer_decoder.py +189 -0
- spar3d/models/illumination/reni/components/vn_layers.py +548 -0
- spar3d/models/illumination/reni/env_map.py +93 -0
- spar3d/models/illumination/reni/field.py +736 -0
- spar3d/models/image_estimator/clip_based_estimator.py +184 -0
- spar3d/models/isosurface.py +229 -0
- spar3d/models/mesh.py +317 -0
- spar3d/models/network.py +223 -0
- spar3d/models/tokenizers/dinov2.py +1196 -0
- spar3d/models/tokenizers/image.py +99 -0
- spar3d/models/tokenizers/point.py +51 -0
- spar3d/models/tokenizers/triplane.py +49 -0
    	
        .gitattributes
    CHANGED
    
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
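The two added rules route GIF and PNG files through Git LFS. This is why the binary images in the file list above (demo_files/comp.gif, turntable.gif, and the example PNGs) appear as +3 -0: only the three-line LFS pointer file (version, oid, size) is stored in the repository itself.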
    	
        .gitignore
    ADDED
    
@@ -0,0 +1,167 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv*/
+env/
+venv*/
+ENV/
+env.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+.vs/
+.idea/
+.vscode/
+
+stabilityai/
+output/
    	
        .pre-commit-config.yaml
    ADDED
    
@@ -0,0 +1,24 @@
+default_language_version:
+  python: python3
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: trailing-whitespace
+      - id: check-ast
+      - id: check-merge-conflict
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+        args: [--markdown-linebreak-ext=md]
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.3.5
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
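These hooks lint and format the codebase with ruff and run basic hygiene checks (trailing whitespace, YAML syntax, merge-conflict markers, end-of-file newlines) before each commit. Contributors would typically enable them by installing the pre-commit package and running "pre-commit install", and can check the whole tree at once with "pre-commit run --all-files".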
    	
        LICENSE.md
    ADDED
    
@@ -0,0 +1,51 @@
+STABILITY AI COMMUNITY LICENSE AGREEMENT
+Last Updated: July 5, 2024
+
+
+I. INTRODUCTION
+
+This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
+
+
+This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
+
+
+By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
+
+II. RESEARCH & NON-COMMERCIAL USE LICENSE
+
+Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
+
+III. COMMERCIAL USE LICENSE
+
+Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
+If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
+
+IV. GENERAL TERMS
+
+Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
+a.  Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
+b.  Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
+c.  Intellectual Property.
+(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
+(ii)  Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
+(iii)  Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
+(iv)  Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
+(v)  Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
+d.  Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
+e.  Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+f.  Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
+g.  Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
+
+V. DEFINITIONS
+
+"Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
+"Agreement" means this Stability AI Community License Agreement.
+"AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
+"Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including "fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
+"Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
+"Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
+"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
+"Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
+"Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
+"Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
    	
        README.md
    CHANGED
    
@@ -1,11 +1,11 @@
 ---
-title: Stable Point
+title: Stable Point-Aware 3D
 emoji: ⚡
 colorFrom: yellow
 colorTo: yellow
 sdk: gradio
-sdk_version:
-app_file:
+sdk_version: 4.43.0
+app_file: gradio_app.py
 pinned: false
 ---
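The values of the removed lines are cut off in this rendering of the diff. The added front-matter lines rename the Space to "Stable Point-Aware 3D", pin the Gradio SDK to 4.43.0, and point the Space at gradio_app.py as its entry script.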
    	
        __init__.py
    ADDED
    
    | @@ -0,0 +1,358 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import base64
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import random
         | 
| 5 | 
            +
            import sys
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import comfy.model_management
         | 
| 8 | 
            +
            import folder_paths
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            import trimesh
         | 
| 12 | 
            +
            from PIL import Image
         | 
| 13 | 
            +
            from trimesh.exchange import gltf
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            sys.path.append(os.path.dirname(__file__))
         | 
| 16 | 
            +
            from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
         | 
| 17 | 
            +
            from spar3d.system import SPAR3D
         | 
| 18 | 
            +
            from spar3d.utils import foreground_crop
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            SPAR3D_CATEGORY = "SPAR3D"
         | 
| 21 | 
            +
            SPAR3D_MODEL_NAME = "stabilityai/spar3d"
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class SPAR3DLoader:
         | 
| 25 | 
            +
                CATEGORY = SPAR3D_CATEGORY
         | 
| 26 | 
            +
                FUNCTION = "load"
         | 
| 27 | 
            +
                RETURN_NAMES = ("spar3d_model",)
         | 
| 28 | 
            +
                RETURN_TYPES = ("SPAR3D_MODEL",)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                @classmethod
         | 
| 31 | 
            +
                def INPUT_TYPES(cls):
         | 
| 32 | 
            +
                    return {"required": {}}
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def load(self):
         | 
| 35 | 
            +
                    device = comfy.model_management.get_torch_device()
         | 
| 36 | 
            +
                    model = SPAR3D.from_pretrained(
         | 
| 37 | 
            +
                        SPAR3D_MODEL_NAME,
         | 
| 38 | 
            +
                        config_name="config.yaml",
         | 
| 39 | 
            +
                        weight_name="model.safetensors",
         | 
| 40 | 
            +
                    )
         | 
| 41 | 
            +
                    model.to(device)
         | 
| 42 | 
            +
                    model.eval()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    return (model,)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            class SPAR3DPreview:
         | 
| 48 | 
            +
                CATEGORY = SPAR3D_CATEGORY
         | 
| 49 | 
            +
                FUNCTION = "preview"
         | 
| 50 | 
            +
                OUTPUT_NODE = True
         | 
| 51 | 
            +
                RETURN_TYPES = ()
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                @classmethod
         | 
| 54 | 
            +
                def INPUT_TYPES(s):
         | 
| 55 | 
            +
                    return {"required": {"mesh": ("MESH",)}}
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def preview(self, mesh):
         | 
| 58 | 
            +
                    glbs = []
         | 
| 59 | 
            +
                    for m in mesh:
         | 
| 60 | 
            +
                        scene = trimesh.Scene(m)
         | 
| 61 | 
            +
                        glb_data = gltf.export_glb(scene, include_normals=True)
         | 
| 62 | 
            +
                        glb_base64 = base64.b64encode(glb_data).decode("utf-8")
         | 
| 63 | 
            +
                        glbs.append(glb_base64)
         | 
| 64 | 
            +
                    return {"ui": {"glbs": glbs}}
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class SPAR3DSampler:
         | 
| 68 | 
            +
                CATEGORY = SPAR3D_CATEGORY
         | 
| 69 | 
            +
                FUNCTION = "predict"
         | 
| 70 | 
            +
                RETURN_NAMES = ("mesh", "pointcloud")
         | 
| 71 | 
            +
                RETURN_TYPES = ("MESH", "POINTCLOUD")
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                @classmethod
         | 
| 74 | 
            +
                def INPUT_TYPES(s):
         | 
| 75 | 
            +
                    remesh_choices = ["none"]
         | 
| 76 | 
            +
                    if TRIANGLE_REMESH_AVAILABLE:
         | 
| 77 | 
            +
                        remesh_choices.append("triangle")
         | 
| 78 | 
            +
                    if QUAD_REMESH_AVAILABLE:
         | 
| 79 | 
            +
                        remesh_choices.append("quad")
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    opt_dict = {
         | 
| 82 | 
            +
                        "mask": ("MASK",),
         | 
| 83 | 
            +
                        "pointcloud": ("POINTCLOUD",),
         | 
| 84 | 
            +
                        "target_type": (["none", "vertex", "face"],),
         | 
| 85 | 
            +
                        "target_count": (
         | 
| 86 | 
            +
                            "INT",
         | 
| 87 | 
            +
                            {"default": 1000, "min": 3, "max": 20000, "step": 1},
         | 
| 88 | 
            +
                        ),
         | 
| 89 | 
            +
                        "guidance_scale": (
         | 
| 90 | 
            +
                            "FLOAT",
         | 
| 91 | 
            +
                            {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
         | 
| 92 | 
            +
                        ),
         | 
| 93 | 
            +
                        "seed": (
         | 
| 94 | 
            +
                            "INT",
         | 
| 95 | 
            +
                            {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
         | 
| 96 | 
            +
                        ),
         | 
| 97 | 
            +
                    }
         | 
| 98 | 
            +
                    if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
         | 
| 99 | 
            +
                        opt_dict["remesh"] = (remesh_choices,)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    return {
         | 
| 102 | 
            +
                        "required": {
         | 
| 103 | 
            +
                            "model": ("SPAR3D_MODEL",),
         | 
| 104 | 
            +
                            "image": ("IMAGE",),
         | 
| 105 | 
            +
                            "foreground_ratio": (
         | 
| 106 | 
            +
                                "FLOAT",
         | 
| 107 | 
            +
                                {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
         | 
| 108 | 
            +
                            ),
         | 
| 109 | 
            +
                            "texture_resolution": (
         | 
| 110 | 
            +
                                "INT",
         | 
| 111 | 
            +
                                {"default": 1024, "min": 512, "max": 2048, "step": 256},
         | 
| 112 | 
            +
                            ),
         | 
| 113 | 
            +
                        },
         | 
| 114 | 
            +
                        "optional": opt_dict,
         | 
| 115 | 
            +
                    }
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def predict(
         | 
| 118 | 
            +
                    s,
         | 
| 119 | 
            +
                    model,
         | 
| 120 | 
            +
                    image,
         | 
| 121 | 
            +
                    mask,
         | 
| 122 | 
            +
                    foreground_ratio,
         | 
| 123 | 
            +
                    texture_resolution,
         | 
| 124 | 
            +
                    pointcloud=None,
         | 
| 125 | 
            +
                    remesh="none",
         | 
| 126 | 
            +
                    target_type="none",
         | 
| 127 | 
            +
                    target_count=1000,
         | 
| 128 | 
            +
                    guidance_scale=3.0,
         | 
| 129 | 
            +
                    seed=42,
         | 
| 130 | 
            +
                ):
         | 
| 131 | 
            +
                    if image.shape[0] != 1:
         | 
| 132 | 
            +
                        raise ValueError("Only one image can be processed at a time")
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    vertex_count = (
         | 
| 135 | 
            +
                        -1
         | 
| 136 | 
            +
                        if target_type == "none"
         | 
| 137 | 
            +
                        else (target_count // 2 if target_type == "face" else target_count)
         | 
| 138 | 
            +
                    )
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    pil_image = Image.fromarray(
         | 
| 141 | 
            +
                        torch.clamp(torch.round(255.0 * image[0]), 0, 255)
         | 
| 142 | 
            +
                        .type(torch.uint8)
         | 
| 143 | 
            +
                        .cpu()
         | 
| 144 | 
            +
                        .numpy()
         | 
| 145 | 
            +
                    )
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    if mask is not None:
         | 
| 148 | 
            +
                        print("Using Mask")
         | 
| 149 | 
            +
                        mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
         | 
| 150 | 
            +
                            np.uint8
         | 
| 151 | 
            +
                        )
         | 
| 152 | 
            +
                        mask_pil = Image.fromarray(mask_np, mode="L")
         | 
| 153 | 
            +
                        pil_image.putalpha(mask_pil)
         | 
| 154 | 
            +
                    else:
         | 
| 155 | 
            +
                        if image.shape[3] != 4:
         | 
| 156 | 
            +
                            print("No mask or alpha channel detected, Converting to RGBA")
         | 
| 157 | 
            +
                            pil_image = pil_image.convert("RGBA")
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    pil_image = foreground_crop(pil_image, foreground_ratio)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    model.cfg.guidance_scale = guidance_scale
         | 
| 162 | 
            +
                    random.seed(seed)
         | 
| 163 | 
            +
                    torch.manual_seed(seed)
         | 
| 164 | 
            +
                    np.random.seed(seed)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    print(remesh)
         | 
| 167 | 
            +
                    with torch.no_grad():
         | 
| 168 | 
            +
                        with torch.autocast(device_type="cuda", dtype=torch.float16):
         | 
| 169 | 
            +
                            if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle":
         | 
| 170 | 
            +
                                raise ImportError(
         | 
| 171 | 
            +
                                    "Triangle remeshing requires gpytoolbox to be installed"
         | 
| 172 | 
            +
                                )
         | 
| 173 | 
            +
                            if not QUAD_REMESH_AVAILABLE and remesh == "quad":
         | 
| 174 | 
            +
                                raise ImportError("Quad remeshing requires pynim to be installed")
         | 
| 175 | 
            +
                            mesh, glob_dict = model.run_image(
         | 
| 176 | 
            +
                                pil_image,
         | 
| 177 | 
            +
                                bake_resolution=texture_resolution,
         | 
| 178 | 
            +
                                pointcloud=pointcloud,
         | 
| 179 | 
            +
                                remesh=remesh,
         | 
| 180 | 
            +
                                vertex_count=vertex_count,
         | 
| 181 | 
            +
                            )
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    if mesh.vertices.shape[0] == 0:
         | 
| 184 | 
            +
                        raise ValueError("No subject detected in the image")
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    return (
         | 
| 187 | 
            +
                        [mesh],
         | 
| 188 | 
            +
                        glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
         | 
| 189 | 
            +
                    )
         | 
| 190 | 
            +
             | 
| 191 | 
            +
             | 
| 192 | 
            +
            class SPAR3DSave:
         | 
| 193 | 
            +
                CATEGORY = SPAR3D_CATEGORY
         | 
| 194 | 
            +
                FUNCTION = "save"
         | 
| 195 | 
            +
                OUTPUT_NODE = True
         | 
| 196 | 
            +
                RETURN_TYPES = ()
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                @classmethod
         | 
| 199 | 
            +
                def INPUT_TYPES(s):
         | 
| 200 | 
            +
                    return {
         | 
| 201 | 
            +
                        "required": {
         | 
| 202 | 
            +
                            "mesh": ("MESH",),
         | 
| 203 | 
            +
                            "filename_prefix": ("STRING", {"default": "SPAR3D"}),
         | 
| 204 | 
            +
                        }
         | 
| 205 | 
            +
                    }
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                def __init__(self):
         | 
| 208 | 
            +
                    self.type = "output"
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                def save(self, mesh, filename_prefix):
         | 
| 211 | 
            +
                    output_dir = folder_paths.get_output_directory()
         | 
| 212 | 
            +
                    glbs = []
         | 
| 213 | 
            +
                    for idx, m in enumerate(mesh):
         | 
| 214 | 
            +
                        scene = trimesh.Scene(m)
         | 
| 215 | 
            +
                        glb_data = gltf.export_glb(scene, include_normals=True)
         | 
| 216 | 
            +
                        logging.info(f"Generated GLB model with {len(glb_data)} bytes")
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        full_output_folder, filename, counter, subfolder, filename_prefix = (
         | 
| 219 | 
            +
                            folder_paths.get_save_image_path(filename_prefix, output_dir)
         | 
| 220 | 
            +
                        )
         | 
| 221 | 
            +
                        filename = filename.replace("%batch_num%", str(idx))
         | 
| 222 | 
            +
                        out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
         | 
| 223 | 
            +
                        with open(out_path, "wb") as f:
         | 
| 224 | 
            +
                            f.write(glb_data)
         | 
| 225 | 
            +
                        glbs.append(base64.b64encode(glb_data).decode("utf-8"))
         | 
| 226 | 
            +
                    return {"ui": {"glbs": glbs}}
         | 
| 227 | 
            +
             | 
| 228 | 
            +
             | 
| 229 | 
            +
            class SPAR3DPointCloudLoader:
         | 
| 230 | 
            +
                CATEGORY = SPAR3D_CATEGORY
         | 
| 231 | 
            +
                FUNCTION = "load_pointcloud"
         | 
| 232 | 
            +
                RETURN_TYPES = ("POINTCLOUD",)
         | 
| 233 | 
            +
                RETURN_NAMES = ("pointcloud",)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                @classmethod
         | 
| 236 | 
            +
                def INPUT_TYPES(cls):
         | 
| 237 | 
            +
                    return {
         | 
| 238 | 
            +
                        "required": {
         | 
| 239 | 
            +
                            "file": ("STRING", {"default": None}),
         | 
| 240 | 
            +
                        }
         | 
| 241 | 
            +
                    }

    def load_pointcloud(self, file):
        if file is None or file == "":
            return (None,)
        # Load the mesh using trimesh
        mesh = trimesh.load(file)

        # Extract vertices and colors
        vertices = mesh.vertices

        # Get vertex colors, defaulting to white if none exist
        if mesh.visual.vertex_colors is not None:
            colors = (
                mesh.visual.vertex_colors[:, :3] / 255.0
            )  # Convert 0-255 to 0-1 range
        else:
            colors = np.ones((len(vertices), 3))

        # Interleave XYZ and RGB values
        point_cloud = []
        for vertex, color in zip(vertices, colors):
            point_cloud.extend(
                [
                    float(vertex[0]),
                    float(vertex[1]),
                    float(vertex[2]),
                    float(color[0]),
                    float(color[1]),
                    float(color[2]),
                ]
            )

        return (point_cloud,)


class SPAR3DPointCloudSaver:
    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save_pointcloud"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pointcloud": ("POINTCLOUD",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def save_pointcloud(self, pointcloud, filename_prefix):
        if pointcloud is None:
            return {"ui": {"text": "No point cloud data to save"}}

        # Reshape the flat list into points with XYZ and RGB
        points = np.array(pointcloud).reshape(-1, 6)

        # Create vertex array for PLY
        vertex_array = np.zeros(
            len(points),
            dtype=[
                ("x", "f4"),
                ("y", "f4"),
                ("z", "f4"),
                ("red", "u1"),
                ("green", "u1"),
                ("blue", "u1"),
            ],
        )

        # Fill vertex array
        vertex_array["x"] = points[:, 0]
        vertex_array["y"] = points[:, 1]
        vertex_array["z"] = points[:, 2]
        # Convert RGB from 0-1 to 0-255 range
        vertex_array["red"] = (points[:, 3] * 255).astype(np.uint8)
        vertex_array["green"] = (points[:, 4] * 255).astype(np.uint8)
        vertex_array["blue"] = (points[:, 5] * 255).astype(np.uint8)

        # Create PLY object
        ply_data = trimesh.PointCloud(
            vertices=points[:, :3], colors=points[:, 3:] * 255
        )

        # Save to file
        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.ply")

        ply_data.export(out_path)

        return {"ui": {"text": f"Saved point cloud to {out_path}"}}


NODE_DISPLAY_NAME_MAPPINGS = {
    "SPAR3DLoader": "SPAR3D Loader",
    "SPAR3DPreview": "SPAR3D Preview",
    "SPAR3DSampler": "SPAR3D Sampler",
    "SPAR3DSave": "SPAR3D Save",
    "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
    "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
}

NODE_CLASS_MAPPINGS = {
    "SPAR3DLoader": SPAR3DLoader,
    "SPAR3DPreview": SPAR3DPreview,
    "SPAR3DSampler": SPAR3DSampler,
    "SPAR3DSave": SPAR3DSave,
    "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
    "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
}

WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
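The POINTCLOUD value exchanged by SPAR3DPointCloudLoader and SPAR3DPointCloudSaver is a flat Python list of interleaved x, y, z, r, g, b floats with colors in the 0-1 range. As a quick sanity check outside ComfyUI, the saved PLY can be loaded back with trimesh and flattened into the same layout; this is a minimal sketch, and the output path is only an example of what the saver's prefix/counter naming produces:

import numpy as np
import trimesh

# Example path only: the saver writes <prefix>_<counter>.ply into the ComfyUI output directory.
pc = trimesh.load("output/SPAR3D_00001.ply")
vertices = np.asarray(pc.vertices)             # (N, 3) positions
colors = np.asarray(pc.colors)[:, :3] / 255.0  # colors back in the 0-1 range

# Rebuild the flat [x, y, z, r, g, b, ...] list that the loader node returns.
flat = np.concatenate([vertices, colors], axis=1).reshape(-1).tolist()
assert len(flat) % 6 == 0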
    	
demo_files/comp.gif ADDED (Git LFS)
demo_files/examples/bird.png ADDED (Git LFS)
demo_files/examples/castle.png ADDED (Git LFS)
demo_files/examples/chest.png ADDED (Git LFS)
demo_files/examples/doll.png ADDED (Git LFS)
demo_files/examples/excavator.png ADDED (Git LFS)
demo_files/examples/fish.png ADDED (Git LFS)
demo_files/examples/horse-statue.png ADDED (Git LFS)
demo_files/examples/penguin.png ADDED (Git LFS)
demo_files/examples/pot.png ADDED (Git LFS)
demo_files/examples/raccoon_wizard.png ADDED (Git LFS)
demo_files/examples/stylized-rocks.png ADDED (Git LFS)
demo_files/hdri/abandoned_tiled_room_1k.hdr ADDED (binary, 478 kB)
demo_files/hdri/metro_noord_1k.hdr ADDED (binary, 467 kB)
demo_files/hdri/neon_photostudio_1k.hdr ADDED (binary, 438 kB)
demo_files/hdri/peppermint_powerplant_1k.hdr ADDED (binary, 473 kB)
demo_files/hdri/rainforest_trail_1k.hdr ADDED (binary, 512 kB)
demo_files/hdri/studio_small_08_1k.hdr ADDED (binary, 412 kB)
demo_files/hdri/urban_alley_01_1k.hdr ADDED (binary, 458 kB)
demo_files/turntable.gif ADDED (Git LFS)
    	
demo_files/workflows/spar3d_example.json ADDED
@@ -0,0 +1,263 @@
{
  "last_node_id": 17,
  "last_link_id": 18,
  "nodes": [
    {
      "id": 10,
      "type": "SPAR3DLoader",
      "pos": [52.92446517944336, 394.328369140625],
      "size": [210, 26],
      "flags": {},
      "order": 0,
      "mode": 0,
      "inputs": [],
      "outputs": [
        {"name": "spar3d_model", "type": "SPAR3D_MODEL", "links": [10], "slot_index": 0}
      ],
      "properties": {"Node name for S&R": "SPAR3DLoader"},
      "widgets_values": []
    },
    {
      "id": 13,
      "type": "LoadImage",
      "pos": [-43.437347412109375, 482.89678955078125],
      "size": [315, 314],
      "flags": {},
      "order": 1,
      "mode": 0,
      "inputs": [],
      "outputs": [
        {"name": "IMAGE", "type": "IMAGE", "links": [11], "slot_index": 0},
        {"name": "MASK", "type": "MASK", "links": [16], "slot_index": 1}
      ],
      "properties": {"Node name for S&R": "LoadImage"},
      "widgets_values": ["cat1.png", "image"]
    },
    {
      "id": 16,
      "type": "InvertMask",
      "pos": [377.1180419921875, 605.384765625],
      "size": [210, 26],
      "flags": {},
      "order": 2,
      "mode": 0,
      "inputs": [
        {"name": "mask", "type": "MASK", "link": 16}
      ],
      "outputs": [
        {"name": "MASK", "type": "MASK", "links": [17], "slot_index": 0}
      ],
      "properties": {"Node name for S&R": "InvertMask"},
      "widgets_values": []
    },
    {
      "id": 17,
      "type": "SPAR3DSave",
      "pos": [1133.669921875, 439.6551513671875],
      "size": [315, 58],
      "flags": {},
      "order": 4,
      "mode": 0,
      "inputs": [
        {"name": "mesh", "type": "MESH", "link": 18}
      ],
      "outputs": [],
      "properties": {"Node name for S&R": "SPAR3DSave"},
      "widgets_values": ["SPAR3D"]
    },
    {
      "id": 11,
      "type": "SPAR3DSampler",
      "pos": [673.0637817382812, 441.2229309082031],
      "size": [315, 286],
      "flags": {},
      "order": 3,
      "mode": 0,
      "inputs": [
        {"name": "model", "type": "SPAR3D_MODEL", "link": 10},
        {"name": "image", "type": "IMAGE", "link": 11},
        {"name": "mask", "type": "MASK", "link": 17, "shape": 7},
        {"name": "pointcloud", "type": "POINTCLOUD", "link": null, "shape": 7}
      ],
      "outputs": [
        {"name": "mesh", "type": "MESH", "links": [18], "slot_index": 0},
        {"name": "pointcloud", "type": "POINTCLOUD", "links": null}
      ],
      "properties": {"Node name for S&R": "SPAR3DSampler"},
      "widgets_values": [1.3, 1024, "none", 1000, 3, 3727502160, "randomize", "none"]
    }
  ],
  "links": [
    [10, 10, 0, 11, 0, "SPAR3D_MODEL"],
    [11, 13, 0, 11, 1, "IMAGE"],
    [16, 13, 1, 16, 0, "MASK"],
    [17, 16, 0, 11, 2, "MASK"],
    [18, 11, 0, 17, 0, "MESH"]
  ],
  "groups": [],
  "config": {},
  "extra": {
    "ds": {
      "scale": 0.953502721998243,
      "offset": [266.21995970220667, 116.75398112171928]
    }
  },
  "version": 0.4
}
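The example workflow wires the nodes registered above into a single image-to-mesh graph: LoadImage supplies the input image and its mask, InvertMask flips the mask, SPAR3DSampler consumes the model from SPAR3DLoader together with the image and inverted mask (the optional point cloud input is left unconnected), and its mesh output feeds SPAR3DSave. Each entry in "links" is [link_id, source_node, source_slot, target_node, target_slot, data_type], so the graph can be inspected with a short script; a minimal sketch:

import json

# Print the edges encoded in the example workflow file.
with open("demo_files/workflows/spar3d_example.json") as f:
    workflow = json.load(f)

node_types = {node["id"]: node["type"] for node in workflow["nodes"]}
# Each link is [link_id, src_node, src_slot, dst_node, dst_slot, data_type].
for _, src, src_slot, dst, dst_slot, dtype in workflow["links"]:
    print(f"{node_types[src]}[{src_slot}] --{dtype}--> {node_types[dst]}[{dst_slot}]")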
    	
gradio_app.py ADDED
@@ -0,0 +1,792 @@
import os

os.system("pip install ./texture_baker/ ./uv_unwrapper/")

import random
import tempfile
import time
from contextlib import nullcontext
from functools import lru_cache
from typing import Any

import gradio as gr
import numpy as np
import torch
import trimesh
from gradio_litmodel3d import LitModel3D
from gradio_pointcloudeditor import PointCloudEditor
from PIL import Image
from transparent_background import Remover

import spar3d.utils as spar3d_utils
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D

os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")

bg_remover = Remover()  # default setting

COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 2.2
COND_FOVY = 0.591627
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Cached. Doesn't change
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

generated_files = []

# Delete previous gradio temp dir folder
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
    print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
    import shutil

    shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])

device = spar3d_utils.get_device()

model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval()
model = model.to(device)

example_files = [
    os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
]


def forward_model(
    batch,
    system,
    guidance_scale=3.0,
    seed=0,
    device="cuda",
    remesh_option="none",
    vertex_count=-1,
    texture_resolution=1024,
):
    batch_size = batch["rgb_cond"].shape[0]

    # prepare the condition for point cloud generation
    # set seed
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    cond_tokens = system.forward_pdiff_cond(batch)

    if "pc_cond" not in batch:
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device,
        )
        for x in sample_iter:
            samples = x["xstart"]
        batch["pc_cond"] = samples.permute(0, 2, 1).float()
        batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])

    # subsample to the 512 points
    batch["pc_cond"] = batch["pc_cond"][
        :, torch.randperm(batch["pc_cond"].shape[1])[:512]
    ]

    # get the point cloud
    xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
    color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
    pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)

    # forward for the final mesh
    trimesh_mesh, _glob_dict = model.generate_mesh(
        batch, texture_resolution, remesh=remesh_option, vertex_count=vertex_count
    )
    trimesh_mesh = trimesh_mesh[0]

    return trimesh_mesh, pc_rgb_trimesh


def run_model(
    input_image,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count,
    texture_resolution,
):
    start = time.time()
    with torch.no_grad():
        with (
            torch.autocast(device_type=device, dtype=torch.float16)
            if "cuda" in device
            else nullcontext()
        ):
            model_batch = create_batch(input_image)
            model_batch = {k: v.to(device) for k, v in model_batch.items()}

            if pc_cond is not None:
                # Check if pc_cond is a list
                if isinstance(pc_cond, list):
                    cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6)
                    xyz = cond_tensor[:, :3]
                    color_rgb = cond_tensor[:, 3:]
                elif isinstance(pc_cond, dict):
                    xyz = torch.tensor(pc_cond["positions"]).float().cuda()
                    color_rgb = torch.tensor(pc_cond["colors"]).float().cuda()
                else:
                    xyz = torch.tensor(pc_cond.vertices).float().cuda()
                    color_rgb = (
                        torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0
                    )
                model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(
                    0
                )
                # sub-sample the point cloud to the target number of points
                if model_batch["pc_cond"].shape[1] > 512:
                    idx = torch.randperm(model_batch["pc_cond"].shape[1])[:512]
                    model_batch["pc_cond"] = model_batch["pc_cond"][:, idx]
                elif model_batch["pc_cond"].shape[1] < 512:
                    num_points = model_batch["pc_cond"].shape[1]
                    gr.Warning(
                        f"The uploaded point cloud should have at least 512 points. This point cloud only has {num_points}. Results may be worse."
                    )
                    pad = 512 - num_points
                    sampled_idx = torch.randint(
                        0, model_batch["pc_cond"].shape[1], (pad,)
                    )
                    model_batch["pc_cond"] = torch.cat(
                        [
                            model_batch["pc_cond"],
                            model_batch["pc_cond"][:, sampled_idx],
                        ],
                        dim=1,
                    )

            trimesh_mesh, trimesh_pc = forward_model(
                model_batch,
                model,
                guidance_scale=guidance_scale,
                seed=random_seed,
                device="cuda",
                remesh_option=remesh_option.lower(),
                vertex_count=vertex_count,
                texture_resolution=texture_resolution,
            )

    # Create new tmp file
    temp_dir = tempfile.mkdtemp()
    tmp_file = os.path.join(temp_dir, "mesh.glb")

    trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
    generated_files.append(tmp_file)

    tmp_file_pc = os.path.join(temp_dir, "points.ply")
    trimesh_pc.export(tmp_file_pc)
    generated_files.append(tmp_file_pc)

    print("Generation took:", time.time() - start, "s")

    return tmp_file, tmp_file_pc, trimesh_pc


def create_batch(input_image: Image) -> dict[str, Any]:
    img_cond = (
        torch.from_numpy(
            np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
            / 255.0
        )
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched


@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    base = np.zeros((squares, squares)) + min_value
    base[1::2, ::2] = 1
    base[::2, 1::2] = 1

    repeat_mult = size // squares
    return (
        base.repeat(repeat_mult, axis=0)
        .repeat(repeat_mult, axis=1)[:, :, None]
        .repeat(3, axis=-1)
    )


def remove_background(input_image: Image) -> Image:
    return bg_remover.process(input_image.convert("RGB"))


def show_mask_img(input_image: Image) -> Image:
    img_numpy = np.array(input_image)
    alpha = img_numpy[:, :, 3] / 255.0
    chkb = checkerboard(32, 512) * 255
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")


def process_model_run(
    background_state,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    # Adjust vertex count based on selection
    final_vertex_count = (
        -1
        if vertex_count_type == "Keep Vertex Count"
        else (
            vertex_count // 2
            if vertex_count_type == "Target Face Count"
            else vertex_count
        )
    )
    print(
        f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
    )

    glb_file, pc_file, pc_plot = run_model(
        background_state,
        guidance_scale,
        random_seed,
        pc_cond,
        remesh_option,
        final_vertex_count,
        texture_resolution,
    )
    # Create a single float list of x y z r g b
    point_list = []
    for i in range(pc_plot.vertices.shape[0]):
        point_list.extend(
            [
                pc_plot.vertices[i, 0],
                pc_plot.vertices[i, 1],
                pc_plot.vertices[i, 2],
                pc_plot.colors[i, 0] / 255.0,
                pc_plot.colors[i, 1] / 255.0,
                pc_plot.colors[i, 2] / 255.0,
            ]
        )

    return glb_file, pc_file, point_list


def regenerate_run(
    background_state,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    glb_file, pc_file, point_list = process_model_run(
        background_state,
        guidance_scale,
        random_seed,
        pc_cond,
        remesh_option,
        vertex_count_type,
        vertex_count,
        texture_resolution,
    )
    return (
        gr.update(),  # run_btn
        gr.update(),  # img_proc_state
        gr.update(),  # background_remove_state
        gr.update(),  # preview_removal
        gr.update(value=glb_file, visible=True),  # output_3d
        gr.update(visible=True),  # hdr_row
        gr.update(visible=True),  # point_cloud_row
        gr.update(value=point_list),  # point_cloud_editor
        gr.update(value=pc_file),  # pc_download
        gr.update(visible=False),  # regenerate_btn
    )


def run_button(
    run_btn,
    input_image,
    background_state,
    foreground_ratio,
    no_crop,
    guidance_scale,
    random_seed,
    pc_upload,
    pc_cond_file,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    if run_btn == "Run":
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        if pc_upload:
            # make sure the pc_cond_file has been uploaded
            try:
                pc_cond = trimesh.load(pc_cond_file.name)
            except Exception:
                raise gr.Error(
                    "Please upload a valid point cloud ply file as condition."
                )
        else:
            pc_cond = None

        glb_file, pc_file, pc_list = process_model_run(
            background_state,
            guidance_scale,
            random_seed,
            pc_cond,
            remesh_option,
            vertex_count_type,
            vertex_count,
            texture_resolution,
        )

        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        return (
            gr.update(),  # run_btn
| 385 | 
            +
                        gr.update(),  # img_proc_state
         | 
| 386 | 
            +
                        gr.update(),  # background_remove_state
         | 
| 387 | 
            +
                        gr.update(),  # preview_removal
         | 
| 388 | 
            +
                        gr.update(value=glb_file, visible=True),  # output_3d
         | 
| 389 | 
            +
                        gr.update(visible=True),  # hdr_row
         | 
| 390 | 
            +
                        gr.update(visible=True),  # point_cloud_row
         | 
| 391 | 
            +
                        gr.update(value=pc_list),  # point_cloud_editor
         | 
| 392 | 
            +
                        gr.update(value=pc_file),  # pc_download
         | 
| 393 | 
            +
                        gr.update(visible=False),  # regenerate_btn
         | 
| 394 | 
            +
                    )
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                elif run_btn == "Remove Background":
         | 
| 397 | 
            +
                    rem_removed = remove_background(input_image)
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    fr_res = spar3d_utils.foreground_crop(
         | 
| 400 | 
            +
                        rem_removed,
         | 
| 401 | 
            +
                        crop_ratio=foreground_ratio,
         | 
| 402 | 
            +
                        newsize=(COND_WIDTH, COND_HEIGHT),
         | 
| 403 | 
            +
                        no_crop=no_crop,
         | 
| 404 | 
            +
                    )
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                    return (
         | 
| 407 | 
            +
                        gr.update(value="Run", visible=True),  # run_btn
         | 
| 408 | 
            +
                        rem_removed,  # img_proc_state,
         | 
| 409 | 
            +
                        fr_res,  # background_remove_state
         | 
| 410 | 
            +
                        gr.update(value=show_mask_img(fr_res), visible=True),  # preview_removal
         | 
| 411 | 
            +
                        gr.update(value=None, visible=False),  # output_3d
         | 
| 412 | 
            +
                        gr.update(visible=False),  # hdr_row
         | 
| 413 | 
            +
                        gr.update(visible=False),  # point_cloud_row
         | 
| 414 | 
            +
                        gr.update(value=None),  # point_cloud_editor
         | 
| 415 | 
            +
                        gr.update(value=None),  # pc_download
         | 
| 416 | 
            +
                        gr.update(visible=False),  # regenerate_btn
         | 
| 417 | 
            +
                    )
         | 
| 418 | 
            +
             | 
| 419 | 
            +
             | 
| 420 | 
            +
            def requires_bg_remove(image, fr, no_crop):
         | 
| 421 | 
            +
                if image is None:
         | 
| 422 | 
            +
                    return (
         | 
| 423 | 
            +
                        gr.update(visible=False, value="Run"),  # run_Btn
         | 
| 424 | 
            +
                        None,  # img_proc_state
         | 
| 425 | 
            +
                        None,  # background_remove_state
         | 
| 426 | 
            +
                        gr.update(value=None, visible=False),  # preview_removal
         | 
| 427 | 
            +
                        gr.update(value=None, visible=False),  # output_3d
         | 
| 428 | 
            +
                        gr.update(visible=False),  # hdr_row
         | 
| 429 | 
            +
                        gr.update(visible=False),  # point_cloud_row
         | 
| 430 | 
            +
                        gr.update(value=None),  # point_cloud_editor
         | 
| 431 | 
            +
                        gr.update(value=None),  # pc_download
         | 
| 432 | 
            +
                        gr.update(visible=False),  # regenerate_btn
         | 
| 433 | 
            +
                    )
         | 
| 434 | 
            +
                alpha_channel = np.array(image.getchannel("A"))
         | 
| 435 | 
            +
                min_alpha = alpha_channel.min()
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                if min_alpha == 0:
         | 
| 438 | 
            +
                    print("Already has alpha")
         | 
| 439 | 
            +
                    fr_res = spar3d_utils.foreground_crop(
         | 
| 440 | 
            +
                        image, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop
         | 
| 441 | 
            +
                    )
         | 
| 442 | 
            +
                    return (
         | 
| 443 | 
            +
                        gr.update(value="Run", visible=True),  # run_Btn
         | 
| 444 | 
            +
                        image,  # img_proc_state
         | 
| 445 | 
            +
                        fr_res,  # background_remove_state
         | 
| 446 | 
            +
                        gr.update(value=show_mask_img(fr_res), visible=True),  # preview_removal
         | 
| 447 | 
            +
                        gr.update(value=None, visible=False),  # output_3d
         | 
| 448 | 
            +
                        gr.update(visible=False),  # hdr_row
         | 
| 449 | 
            +
                        gr.update(visible=False),  # point_cloud_row
         | 
| 450 | 
            +
                        gr.update(value=None),  # point_cloud_editor
         | 
| 451 | 
            +
                        gr.update(value=None),  # pc_download
         | 
| 452 | 
            +
                        gr.update(visible=False),  # regenerate_btn
         | 
| 453 | 
            +
                    )
         | 
| 454 | 
            +
                return (
         | 
| 455 | 
            +
                    gr.update(value="Remove Background", visible=True),  # run_Btn
         | 
| 456 | 
            +
                    None,  # img_proc_state
         | 
| 457 | 
            +
                    None,  # background_remove_state
         | 
| 458 | 
            +
                    gr.update(value=None, visible=False),  # preview_removal
         | 
| 459 | 
            +
                    gr.update(value=None, visible=False),  # output_3d
         | 
| 460 | 
            +
                    gr.update(visible=False),  # hdr_row
         | 
| 461 | 
            +
                    gr.update(visible=False),  # point_cloud_row
         | 
| 462 | 
            +
                    gr.update(value=None),  # point_cloud_editor
         | 
| 463 | 
            +
                    gr.update(value=None),  # pc_download
         | 
| 464 | 
            +
                    gr.update(visible=False),  # regenerate_btn
         | 
| 465 | 
            +
                )
         | 
| 466 | 
            +
             | 
| 467 | 
            +
             | 
| 468 | 
            +
            def update_foreground_ratio(img_proc, fr, no_crop):
         | 
| 469 | 
            +
                foreground_res = spar3d_utils.foreground_crop(
         | 
| 470 | 
            +
                    img_proc, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop
         | 
| 471 | 
            +
                )
         | 
| 472 | 
            +
                return (
         | 
| 473 | 
            +
                    foreground_res,
         | 
| 474 | 
            +
                    gr.update(value=show_mask_img(foreground_res)),
         | 
| 475 | 
            +
                )
         | 
| 476 | 
            +
             | 
| 477 | 
            +
             | 
| 478 | 
            +
            def update_resolution_controls(remesh_choice, vertex_count_type):
         | 
| 479 | 
            +
                show_controls = remesh_choice.lower() != "none"
         | 
| 480 | 
            +
                show_vertex_count = vertex_count_type != "Keep Vertex Count"
         | 
| 481 | 
            +
                return (
         | 
| 482 | 
            +
                    gr.update(visible=show_controls),  # vertex_count_type
         | 
| 483 | 
            +
                    gr.update(visible=show_controls and show_vertex_count),  # vertex_count_slider
         | 
| 484 | 
            +
                )
         | 
| 485 | 
            +
             | 
| 486 | 
            +
             | 
| 487 | 
            +
            with gr.Blocks() as demo:
         | 
| 488 | 
            +
                img_proc_state = gr.State()
         | 
| 489 | 
            +
                background_remove_state = gr.State()
         | 
| 490 | 
            +
                gr.Markdown(
         | 
| 491 | 
            +
                    """
         | 
| 492 | 
            +
                # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                SPAR3D is a state-of-the-art method for 3D mesh reconstruction from a single image. This demo allows you to upload an image and generate a 3D mesh model from it. A feature of SPAR3D is it generates point clouds as intermediate representation before producing the mesh. You can edit the point cloud to adjust the final mesh. We provide a simple point cloud editor in this demo, where you can drag, recolor and rescale the point clouds. If you have more advanced editing needs (e.g. box selection, duplication, local streching, etc.), you can download the point cloud and edit it in softwares such as MeshLab or Blender. The edited point cloud can then be uploaded to this demo to generate a new 3D model by checking the "Point cloud upload" box.
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                **Tips**
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                1. If the image does not have a valid alpha channel, it will go through the background removal step. Our built-in background removal can be inaccurate sometimes, which will result in poor mesh quality. In such cases, you can use external background removal tools to obtain a RGBA image before uploading here.
         | 
| 499 | 
            +
                2. You can adjust the foreground ratio to control the size of the foreground object. This may have major impact on the final mesh.
         | 
| 500 | 
            +
                3. Guidance scale controls the strength of the image condition in the point cloud generation process. A higher value may result in higher mesh fidelity, but the variability by changing the random seed will be lower. Note that the guidance scale and the seed are not effective when the point cloud is manually uploaded.
         | 
| 501 | 
            +
                4. Our online editor supports multi-selection by holding down the shift key. This allows you to recolor multiple points at once.
         | 
| 502 | 
            +
                5. The editing should mainly alter the unseen parts of the object. Visible parts can be edited, but the edits should be consistent with the image. Editing the visible parts in a way that contradicts the image may result in poor mesh quality.
         | 
| 503 | 
            +
                6. You can upload your own HDR environment map to light the 3D model.
         | 
| 504 | 
            +
                """
         | 
| 505 | 
            +
                )
         | 
| 506 | 
            +
                with gr.Row(variant="panel"):
         | 
| 507 | 
            +
                    with gr.Column():
         | 
| 508 | 
            +
                        with gr.Row():
         | 
| 509 | 
            +
                            input_img = gr.Image(
         | 
| 510 | 
            +
                                type="pil", label="Input Image", sources="upload", image_mode="RGBA"
         | 
| 511 | 
            +
                            )
         | 
| 512 | 
            +
                            preview_removal = gr.Image(
         | 
| 513 | 
            +
                                label="Preview Background Removal",
         | 
| 514 | 
            +
                                type="pil",
         | 
| 515 | 
            +
                                image_mode="RGB",
         | 
| 516 | 
            +
                                interactive=False,
         | 
| 517 | 
            +
                                visible=False,
         | 
| 518 | 
            +
                            )
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                        gr.Markdown("### Input Controls")
         | 
| 521 | 
            +
                        with gr.Group():
         | 
| 522 | 
            +
                            with gr.Row():
         | 
| 523 | 
            +
                                no_crop = gr.Checkbox(label="No cropping", value=False)
         | 
| 524 | 
            +
                                pc_upload = gr.Checkbox(label="Point cloud upload", value=False)
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                            pc_cond_file = gr.File(
         | 
| 527 | 
            +
                                label="Point Cloud Upload",
         | 
| 528 | 
            +
                                file_types=[".ply"],
         | 
| 529 | 
            +
                                file_count="single",
         | 
| 530 | 
            +
                                visible=False,
         | 
| 531 | 
            +
                            )
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                            foreground_ratio = gr.Slider(
         | 
| 534 | 
            +
                                label="Padding Ratio",
         | 
| 535 | 
            +
                                minimum=1.0,
         | 
| 536 | 
            +
                                maximum=2.0,
         | 
| 537 | 
            +
                                value=1.3,
         | 
| 538 | 
            +
                                step=0.05,
         | 
| 539 | 
            +
                            )
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                        pc_upload.change(
         | 
| 542 | 
            +
                            lambda x: gr.update(visible=x),
         | 
| 543 | 
            +
                            inputs=pc_upload,
         | 
| 544 | 
            +
                            outputs=[pc_cond_file],
         | 
| 545 | 
            +
                        )
         | 
| 546 | 
            +
             | 
| 547 | 
            +
                        no_crop.change(
         | 
| 548 | 
            +
                            update_foreground_ratio,
         | 
| 549 | 
            +
                            inputs=[img_proc_state, foreground_ratio, no_crop],
         | 
| 550 | 
            +
                            outputs=[background_remove_state, preview_removal],
         | 
| 551 | 
            +
                        )
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                        foreground_ratio.change(
         | 
| 554 | 
            +
                            update_foreground_ratio,
         | 
| 555 | 
            +
                            inputs=[img_proc_state, foreground_ratio, no_crop],
         | 
| 556 | 
            +
                            outputs=[background_remove_state, preview_removal],
         | 
| 557 | 
            +
                        )
         | 
| 558 | 
            +
             | 
| 559 | 
            +
                        gr.Markdown("### Point Diffusion Controls")
         | 
| 560 | 
            +
                        with gr.Group():
         | 
| 561 | 
            +
                            guidance_scale = gr.Slider(
         | 
| 562 | 
            +
                                label="Guidance Scale",
         | 
| 563 | 
            +
                                minimum=1.0,
         | 
| 564 | 
            +
                                maximum=10.0,
         | 
| 565 | 
            +
                                value=3.0,
         | 
| 566 | 
            +
                                step=1.0,
         | 
| 567 | 
            +
                            )
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                            random_seed = gr.Slider(
         | 
| 570 | 
            +
                                label="Seed",
         | 
| 571 | 
            +
                                minimum=0,
         | 
| 572 | 
            +
                                maximum=10000,
         | 
| 573 | 
            +
                                value=0,
         | 
| 574 | 
            +
                                step=1,
         | 
| 575 | 
            +
                            )
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                        no_remesh = not TRIANGLE_REMESH_AVAILABLE and not QUAD_REMESH_AVAILABLE
         | 
| 578 | 
            +
                        gr.Markdown(
         | 
| 579 | 
            +
                            "### Texture Controls"
         | 
| 580 | 
            +
                            if no_remesh
         | 
| 581 | 
            +
                            else "### Meshing and Texture Controls"
         | 
| 582 | 
            +
                        )
         | 
| 583 | 
            +
                        with gr.Group():
         | 
| 584 | 
            +
                            remesh_choices = ["None"]
         | 
| 585 | 
            +
                            if TRIANGLE_REMESH_AVAILABLE:
         | 
| 586 | 
            +
                                remesh_choices.append("Triangle")
         | 
| 587 | 
            +
                            if QUAD_REMESH_AVAILABLE:
         | 
| 588 | 
            +
                                remesh_choices.append("Quad")
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                            remesh_option = gr.Radio(
         | 
| 591 | 
            +
                                choices=remesh_choices,
         | 
| 592 | 
            +
                                label="Remeshing",
         | 
| 593 | 
            +
                                value="None",
         | 
| 594 | 
            +
                                visible=not no_remesh,
         | 
| 595 | 
            +
                            )
         | 
| 596 | 
            +
             | 
| 597 | 
            +
                            vertex_count_type = gr.Radio(
         | 
| 598 | 
            +
                                choices=[
         | 
| 599 | 
            +
                                    "Keep Vertex Count",
         | 
| 600 | 
            +
                                    "Target Vertex Count",
         | 
| 601 | 
            +
                                    "Target Face Count",
         | 
| 602 | 
            +
                                ],
         | 
| 603 | 
            +
                                label="Mesh Resolution Control",
         | 
| 604 | 
            +
                                value="Keep Vertex Count",
         | 
| 605 | 
            +
                                visible=False,
         | 
| 606 | 
            +
                            )
         | 
| 607 | 
            +
             | 
| 608 | 
            +
                            vertex_count_slider = gr.Slider(
         | 
| 609 | 
            +
                                label="Target Count",
         | 
| 610 | 
            +
                                minimum=0,
         | 
| 611 | 
            +
                                maximum=20000,
         | 
| 612 | 
            +
                                value=2000,
         | 
| 613 | 
            +
                                visible=False,
         | 
| 614 | 
            +
                            )
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                            texture_size = gr.Slider(
         | 
| 617 | 
            +
                                label="Texture Size",
         | 
| 618 | 
            +
                                minimum=512,
         | 
| 619 | 
            +
                                maximum=2048,
         | 
| 620 | 
            +
                                value=1024,
         | 
| 621 | 
            +
                                step=256,
         | 
| 622 | 
            +
                                visible=True,
         | 
| 623 | 
            +
                            )
         | 
| 624 | 
            +
             | 
| 625 | 
            +
                        remesh_option.change(
         | 
| 626 | 
            +
                            update_resolution_controls,
         | 
| 627 | 
            +
                            inputs=[remesh_option, vertex_count_type],
         | 
| 628 | 
            +
                            outputs=[vertex_count_type, vertex_count_slider],
         | 
| 629 | 
            +
                        )
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                        vertex_count_type.change(
         | 
| 632 | 
            +
                            update_resolution_controls,
         | 
| 633 | 
            +
                            inputs=[remesh_option, vertex_count_type],
         | 
| 634 | 
            +
                            outputs=[vertex_count_type, vertex_count_slider],
         | 
| 635 | 
            +
                        )
         | 
| 636 | 
            +
             | 
| 637 | 
            +
                        run_btn = gr.Button("Run", variant="primary", visible=False)
         | 
| 638 | 
            +
             | 
| 639 | 
            +
                    with gr.Column():
         | 
| 640 | 
            +
                        with gr.Group(visible=False) as point_cloud_row:
         | 
| 641 | 
            +
                            point_size_slider = gr.Slider(
         | 
| 642 | 
            +
                                label="Point Size",
         | 
| 643 | 
            +
                                minimum=0.01,
         | 
| 644 | 
            +
                                maximum=1.0,
         | 
| 645 | 
            +
                                value=0.2,
         | 
| 646 | 
            +
                                step=0.01,
         | 
| 647 | 
            +
                            )
         | 
| 648 | 
            +
                            point_cloud_editor = PointCloudEditor(
         | 
| 649 | 
            +
                                up_axis="Z",
         | 
| 650 | 
            +
                                forward_axis="X",
         | 
| 651 | 
            +
                                lock_scale_z=True,
         | 
| 652 | 
            +
                                lock_scale_y=True,
         | 
| 653 | 
            +
                                visible=True,
         | 
| 654 | 
            +
                            )
         | 
| 655 | 
            +
             | 
| 656 | 
            +
                            pc_download = gr.File(
         | 
| 657 | 
            +
                                label="Point Cloud Download",
         | 
| 658 | 
            +
                                file_types=[".ply"],
         | 
| 659 | 
            +
                                file_count="single",
         | 
| 660 | 
            +
                            )
         | 
| 661 | 
            +
                        point_size_slider.change(
         | 
| 662 | 
            +
                            fn=lambda x: gr.update(point_size=x),
         | 
| 663 | 
            +
                            inputs=point_size_slider,
         | 
| 664 | 
            +
                            outputs=point_cloud_editor,
         | 
| 665 | 
            +
                        )
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                        regenerate_btn = gr.Button(
         | 
| 668 | 
            +
                            "Re-run with point cloud", variant="primary", visible=False
         | 
| 669 | 
            +
                        )
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                        output_3d = LitModel3D(
         | 
| 672 | 
            +
                            label="3D Model",
         | 
| 673 | 
            +
                            visible=False,
         | 
| 674 | 
            +
                            clear_color=[0.0, 0.0, 0.0, 0.0],
         | 
| 675 | 
            +
                            tonemapping="aces",
         | 
| 676 | 
            +
                            contrast=1.0,
         | 
| 677 | 
            +
                            scale=1.0,
         | 
| 678 | 
            +
                        )
         | 
| 679 | 
            +
                        with gr.Column(visible=False, scale=1.0) as hdr_row:
         | 
| 680 | 
            +
                            gr.Markdown(
         | 
| 681 | 
            +
                                """## HDR Environment Map
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                            Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
         | 
| 684 | 
            +
                            """
         | 
| 685 | 
            +
                            )
         | 
| 686 | 
            +
             | 
| 687 | 
            +
                            with gr.Row():
         | 
| 688 | 
            +
                                hdr_illumination_file = gr.File(
         | 
| 689 | 
            +
                                    label="HDR Env Map",
         | 
| 690 | 
            +
                                    file_types=[".hdr"],
         | 
| 691 | 
            +
                                    file_count="single",
         | 
| 692 | 
            +
                                )
         | 
| 693 | 
            +
                                example_hdris = [
         | 
| 694 | 
            +
                                    os.path.join("demo_files/hdri", f)
         | 
| 695 | 
            +
                                    for f in os.listdir("demo_files/hdri")
         | 
| 696 | 
            +
                                ]
         | 
| 697 | 
            +
                                hdr_illumination_example = gr.Examples(
         | 
| 698 | 
            +
                                    examples=example_hdris,
         | 
| 699 | 
            +
                                    inputs=hdr_illumination_file,
         | 
| 700 | 
            +
                                )
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                                hdr_illumination_file.change(
         | 
| 703 | 
            +
                                    lambda x: gr.update(env_map=x.name if x is not None else None),
         | 
| 704 | 
            +
                                    inputs=hdr_illumination_file,
         | 
| 705 | 
            +
                                    outputs=[output_3d],
         | 
| 706 | 
            +
                                )
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                examples = gr.Examples(
         | 
| 709 | 
            +
                    examples=example_files, inputs=input_img, examples_per_page=11
         | 
| 710 | 
            +
                )
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                input_img.change(
         | 
| 713 | 
            +
                    requires_bg_remove,
         | 
| 714 | 
            +
                    inputs=[input_img, foreground_ratio, no_crop],
         | 
| 715 | 
            +
                    outputs=[
         | 
| 716 | 
            +
                        run_btn,
         | 
| 717 | 
            +
                        img_proc_state,
         | 
| 718 | 
            +
                        background_remove_state,
         | 
| 719 | 
            +
                        preview_removal,
         | 
| 720 | 
            +
                        output_3d,
         | 
| 721 | 
            +
                        hdr_row,
         | 
| 722 | 
            +
                        point_cloud_row,
         | 
| 723 | 
            +
                        point_cloud_editor,
         | 
| 724 | 
            +
                        pc_download,
         | 
| 725 | 
            +
                        regenerate_btn,
         | 
| 726 | 
            +
                    ],
         | 
| 727 | 
            +
                )
         | 
| 728 | 
            +
             | 
| 729 | 
            +
                point_cloud_editor.edit(
         | 
| 730 | 
            +
                    fn=lambda _x: gr.update(visible=True),
         | 
| 731 | 
            +
                    inputs=point_cloud_editor,
         | 
| 732 | 
            +
                    outputs=regenerate_btn,
         | 
| 733 | 
            +
                )
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                regenerate_btn.click(
         | 
| 736 | 
            +
                    regenerate_run,
         | 
| 737 | 
            +
                    inputs=[
         | 
| 738 | 
            +
                        background_remove_state,
         | 
| 739 | 
            +
                        guidance_scale,
         | 
| 740 | 
            +
                        random_seed,
         | 
| 741 | 
            +
                        point_cloud_editor,
         | 
| 742 | 
            +
                        remesh_option,
         | 
| 743 | 
            +
                        vertex_count_type,
         | 
| 744 | 
            +
                        vertex_count_slider,
         | 
| 745 | 
            +
                        texture_size,
         | 
| 746 | 
            +
                    ],
         | 
| 747 | 
            +
                    outputs=[
         | 
| 748 | 
            +
                        run_btn,
         | 
| 749 | 
            +
                        img_proc_state,
         | 
| 750 | 
            +
                        background_remove_state,
         | 
| 751 | 
            +
                        preview_removal,
         | 
| 752 | 
            +
                        output_3d,
         | 
| 753 | 
            +
                        hdr_row,
         | 
| 754 | 
            +
                        point_cloud_row,
         | 
| 755 | 
            +
                        point_cloud_editor,
         | 
| 756 | 
            +
                        pc_download,
         | 
| 757 | 
            +
                        regenerate_btn,
         | 
| 758 | 
            +
                    ],
         | 
| 759 | 
            +
                )
         | 
| 760 | 
            +
             | 
| 761 | 
            +
                run_btn.click(
         | 
| 762 | 
            +
                    run_button,
         | 
| 763 | 
            +
                    inputs=[
         | 
| 764 | 
            +
                        run_btn,
         | 
| 765 | 
            +
                        input_img,
         | 
| 766 | 
            +
                        background_remove_state,
         | 
| 767 | 
            +
                        foreground_ratio,
         | 
| 768 | 
            +
                        no_crop,
         | 
| 769 | 
            +
                        guidance_scale,
         | 
| 770 | 
            +
                        random_seed,
         | 
| 771 | 
            +
                        pc_upload,
         | 
| 772 | 
            +
                        pc_cond_file,
         | 
| 773 | 
            +
                        remesh_option,
         | 
| 774 | 
            +
                        vertex_count_type,
         | 
| 775 | 
            +
                        vertex_count_slider,
         | 
| 776 | 
            +
                        texture_size,
         | 
| 777 | 
            +
                    ],
         | 
| 778 | 
            +
                    outputs=[
         | 
| 779 | 
            +
                        run_btn,
         | 
| 780 | 
            +
                        img_proc_state,
         | 
| 781 | 
            +
                        background_remove_state,
         | 
| 782 | 
            +
                        preview_removal,
         | 
| 783 | 
            +
                        output_3d,
         | 
| 784 | 
            +
                        hdr_row,
         | 
| 785 | 
            +
                        point_cloud_row,
         | 
| 786 | 
            +
                        point_cloud_editor,
         | 
| 787 | 
            +
                        pc_download,
         | 
| 788 | 
            +
                        regenerate_btn,
         | 
| 789 | 
            +
                    ],
         | 
| 790 | 
            +
                )
         | 
| 791 | 
            +
             | 
| 792 | 
            +
            demo.queue().launch()
         | 
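The demo text above suggests downloading the intermediate point cloud, editing it in external software, and re-uploading the edited .ply with the "Point cloud upload" box checked. The following is a minimal round-trip sketch of that workflow, not part of the commit; the file names are placeholders, and it assumes the downloaded .ply loads as a trimesh.PointCloud with per-vertex colors.

import numpy as np
import trimesh

pc = trimesh.load("points.ply")  # point cloud downloaded from the demo (placeholder name)
verts = np.asarray(pc.vertices).copy()  # (N, 3) point positions
colors = np.asarray(pc.colors).copy()  # (N, 4) RGBA values in 0..255

verts[:, 2] += 0.05  # example edit: nudge all points along +Z
colors[:, :3] = colors[:, :3] // 2  # example edit: darken the point colors

trimesh.PointCloud(vertices=verts, colors=colors).export("points_edited.ply")
# "points_edited.ply" can then be uploaded back through the "Point cloud upload" box.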
    	
load/tets/160_tets.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
size 15408790
    	
requirements.txt
ADDED
@@ -0,0 +1,17 @@
einops==0.7.0
jaxtyping==0.2.31
omegaconf==2.3.0
transformers==4.42.3
loralib==0.1.2
git+https://github.com/openai/CLIP.git
git+https://github.com/SunzeY/AlphaCLIP.git
trimesh==4.4.1
numpy==1.26.4
huggingface-hub==0.23.4
transparent-background==1.3.3
gradio==4.43.0
gradio-litmodel3d==0.0.1
gradio-pointcloudeditor==0.0.9
gpytoolbox==0.2.0
# ./texture_baker/
# ./uv_unwrapper/
    	
ruff.toml
ADDED
@@ -0,0 +1,3 @@
[lint]
ignore = ["F722", "F821"]
extend-select = ["I"]
    	
run.py
ADDED
@@ -0,0 +1,180 @@
import argparse
import os
from contextlib import nullcontext

import torch
from PIL import Image
from tqdm import tqdm
from transparent_background import Remover

from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
from spar3d.utils import foreground_crop, get_device, remove_background


def check_positive(value):
    ivalue = int(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
    return ivalue


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image", type=str, nargs="+", help="Path to input image(s) or folder."
    )
    parser.add_argument(
        "--device",
        default=get_device(),
        type=str,
        help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
    )
    parser.add_argument(
        "--pretrained-model",
        default="stabilityai/spar3d",
        type=str,
        help="Path to the pretrained model. Can be either a huggingface model id or a local path. Default: 'stabilityai/spar3d'",
    )
    parser.add_argument(
        "--foreground-ratio",
        default=1.3,
        type=float,
        help="Ratio of the foreground size to the image size. Default: 1.3",
    )
    parser.add_argument(
        "--output-dir",
        default="output/",
        type=str,
        help="Output directory to save the results. Default: 'output/'",
    )
    parser.add_argument(
        "--texture-resolution",
        default=1024,
        type=int,
        help="Texture atlas resolution. Default: 1024",
    )

    remesh_choices = ["none"]
    if TRIANGLE_REMESH_AVAILABLE:
        remesh_choices.append("triangle")
    if QUAD_REMESH_AVAILABLE:
        remesh_choices.append("quad")
    parser.add_argument(
        "--remesh_option",
        choices=remesh_choices,
        default="none",
        help="Remeshing option",
    )
    if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
        parser.add_argument(
            "--reduction_count_type",
            choices=["keep", "vertex", "faces"],
            default="keep",
            help="Vertex count type",
        )
        parser.add_argument(
            "--target_count",
            type=check_positive,
            help="Selected target count.",
            default=2000,
        )
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    args = parser.parse_args()

    # Ensure the requested device is one of cuda, mps or cpu
    devices = ["cuda", "mps", "cpu"]
    if not any(device in args.device for device in devices):
        raise ValueError("Invalid device. Use cuda, mps or cpu")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    device = args.device
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        device = "cpu"

    print("Device used: ", device)

    model = SPAR3D.from_pretrained(
        args.pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
    )
    model.to(device)
    model.eval()

    bg_remover = Remover(device=device)
    images = []
    idx = 0
    for image_path in args.image:

        def handle_image(image_path, idx):
            image = remove_background(
                Image.open(image_path).convert("RGBA"), bg_remover
            )
            image = foreground_crop(image, args.foreground_ratio)
            os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
            image.save(os.path.join(output_dir, str(idx), "input.png"))
            images.append(image)

        if os.path.isdir(image_path):
            image_paths = [
                os.path.join(image_path, f)
                for f in os.listdir(image_path)
                if f.endswith((".png", ".jpg", ".jpeg"))
            ]
            for image_path in image_paths:
                handle_image(image_path, idx)
                idx += 1
        else:
            handle_image(image_path, idx)
            idx += 1

    # The remeshing arguments are only registered when a remesher is available,
    # so fall back to their defaults if the attributes are missing.
    reduction_count_type = getattr(args, "reduction_count_type", "keep")
    target_count = getattr(args, "target_count", 2000)
    vertex_count = (
        -1
        if reduction_count_type == "keep"
        else (
            target_count
            if reduction_count_type == "vertex"
            else target_count // 2
        )
    )

    for i in tqdm(range(0, len(images), args.batch_size)):
        image = images[i : i + args.batch_size]
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            with (
                torch.autocast(device_type=device, dtype=torch.float16)
                if "cuda" in device
                else nullcontext()
            ):
                mesh, glob_dict = model.run_image(
                    image,
                    bake_resolution=args.texture_resolution,
                    remesh=args.remesh_option,
                    vertex_count=vertex_count,
                    return_points=True,
                )
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        if len(image) == 1:
            out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
            mesh.export(out_mesh_path, include_normals=True)
            out_points_path = os.path.join(output_dir, str(i), "points.ply")
            glob_dict["point_clouds"][0].export(out_points_path)
        else:
            for j in range(len(mesh)):
                out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
                mesh[j].export(out_mesh_path, include_normals=True)
                out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
                glob_dict["point_clouds"][j].export(out_points_path)
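For reference, the vertex_count expression above maps --reduction_count_type onto the single count passed to run_image: "keep" disables reduction (-1), "vertex" uses --target_count directly, and "faces" halves it, since a closed triangle mesh has roughly twice as many faces as vertices. A hypothetical restatement of that mapping, not part of the commit:

def resolve_vertex_count(reduction_count_type: str, target_count: int) -> int:
    # "keep" -> -1 (leave the mesh as-is), "vertex" -> target_count,
    # "faces" -> target_count // 2 (with V - E + F = 2 and E = 3F/2, F is about 2V,
    # so hitting a face budget means asking for roughly half as many vertices).
    if reduction_count_type == "keep":
        return -1
    if reduction_count_type == "vertex":
        return target_count
    return target_count // 2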
    	
        spar3d/models/camera.py
    ADDED
    
    | @@ -0,0 +1,32 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn

from spar3d.models.utils import BaseModule


class LinearCameraEmbedder(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 25
        out_channels: int = 768
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        cond_tensors = []
        for cond_name in self.cfg.conditions:
            assert cond_name in kwargs
            cond = kwargs[cond_name]
            # cond in shape (B, Nv, ...)
            cond_tensors.append(cond.view(*cond.shape[:2], -1))
        cond_tensor = torch.cat(cond_tensors, dim=-1)
        assert cond_tensor.shape[-1] == self.cfg.in_channels
        embedding = self.linear(cond_tensor)
        return embedding
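Shape contract in brief: every name listed in cfg.conditions must be passed to forward as a keyword tensor of shape (B, Nv, ...); each view is flattened and the concatenated width has to equal in_channels. The sketch below is illustrative only. It assumes the BaseModule base class (defined in spar3d/models/utils.py, outside this excerpt) can be constructed from a plain config mapping, and the condition name "c2w_cond" plus the 25-value packing are placeholders rather than SPAR3D's actual configuration.

import torch

from spar3d.models.camera import LinearCameraEmbedder

# Hypothetical construction: config keys mirror LinearCameraEmbedder.Config.
embedder = LinearCameraEmbedder(
    {"in_channels": 25, "out_channels": 768, "conditions": ["c2w_cond"]}
)
cam = torch.randn(2, 1, 25)      # (B, Nv, ...) flattened to 25 values per view
tokens = embedder(c2w_cond=cam)  # -> (2, 1, 768) camera embedding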
    	
        spar3d/models/diffusion/gaussian_diffusion.py
    ADDED
    
@@ -0,0 +1,524 @@
# --------------------------------------------------------
# Adapted from: https://github.com/openai/point-e
# Licensed under the MIT License
# Copyright (c) 2022 OpenAI

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# --------------------------------------------------------

import math
from typing import Any, Dict, Iterable, Optional, Sequence, Union

import numpy as np
import torch as th


def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9):
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    v_start = sigmoid(start / tau)
    v_end = sigmoid(end / tau)
    output = sigmoid((t * (end - start) + start) / tau)
    output = (v_end - output) / (v_end - v_start)
    return np.clip(output, clip_min, 1.0)


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.

    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, exp_p=12):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule_name == "sigmoid":
        # Sigmoid schedule passed through betas_for_alpha_bar
        return betas_for_alpha_bar(
            num_diffusion_timesteps, lambda t: sigmoid_schedule(t)
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        elif section_counts.startswith("exact"):
            res = set(int(x) for x in section_counts[len("exact") :].split(","))
            for x in res:
                if x < 0 or x >= num_timesteps:
                    raise ValueError(f"timestep out of bounds: {x}")
            return res
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """Extract values from a 1-D numpy array for a batch of indices."""
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)


class GaussianDiffusion:
    """
    Utilities for sampling from Gaussian diffusion models.
    """

    def __init__(
        self,
        *,
        betas: Sequence[float],
        model_mean_type: str,
        model_var_type: str,
        channel_scales: Optional[np.ndarray] = None,
        channel_biases: Optional[np.ndarray] = None,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.channel_scales = channel_scales
        self.channel_biases = channel_biases

        # Use float64 for accuracy
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )

        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def scale_channels(self, x: th.Tensor) -> th.Tensor:
        """Apply channel-wise scaling."""
        if self.channel_scales is not None:
            x = x * th.from_numpy(self.channel_scales).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        if self.channel_biases is not None:
            x = x + th.from_numpy(self.channel_biases).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        return x

    def unscale_channels(self, x: th.Tensor) -> th.Tensor:
        """Remove channel-wise scaling."""
        if self.channel_biases is not None:
            x = x - th.from_numpy(self.channel_biases).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        if self.channel_scales is not None:
            x = x / th.from_numpy(self.channel_scales).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        return x

    def unscale_out_dict(
        self, out: Dict[str, Union[th.Tensor, Any]]
    ) -> Dict[str, Union[th.Tensor, Any]]:
        return {
            k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v)
            for k, v in out.items()
        }

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t).
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)

        # Direct prediction of eps
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, prev_latent = model_output
            model_kwargs["prev_latent"] = prev_latent

        # Convert model output to mean and variance
        model_variance, model_log_variance = {
            # for fixedlarge, we set the initial (log-)variance like so
            # to get a better decoder log likelihood.
            "fixed_large": (
                np.append(self.posterior_variance[1], self.betas[1:]),
                np.log(np.append(self.posterior_variance[1], self.betas[1:])),
            ),
            "fixed_small": (
                self.posterior_variance,
                self.posterior_log_variance_clipped,
            ),
        }[self.model_var_type]
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                x = x.clamp(
                    -self.channel_scales[0] * 0.67, self.channel_scales[0] * 0.67
                )
                x[:, 3:] = x[:, 3:].clamp(
                    -self.channel_scales[3] * 0.5, self.channel_scales[3] * 0.5
                )
                return x
            return x

        if self.model_mean_type == "x_prev":
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in ["x_start", "epsilon"]:
            if self.model_mean_type == "x_start":
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
            # print('p_mean_variance:', pred_xstart.min(), pred_xstart.max())
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples.
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield self.unscale_out_dict(out)
                img = out["sample"]

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )

        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
        )
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    """

    def __init__(self, use_timesteps: Iterable[int], **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.original_num_steps)


class _WrappedModel:
    """Helper class to wrap models for SpacedDiffusion."""

    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        return self.model(x, new_ts, **kwargs)
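To see how these pieces compose: get_named_beta_schedule builds the base betas, space_timesteps picks a reduced step set, and SpacedDiffusion re-derives betas for just those steps so ddim_sample_loop_progressive only visits them. The values in the sketch below (1024 base steps, a 64-step DDIM stride, six channels for xyz+rgb) are illustrative assumptions, not the configuration shipped with SPAR3D; note that the clamping in p_mean_variance indexes channel_scales[0] and channel_scales[3], so at least four channels are expected whenever clip_denoised is enabled.

import numpy as np

from spar3d.models.diffusion.gaussian_diffusion import (
    SpacedDiffusion,
    get_named_beta_schedule,
    space_timesteps,
)

base_steps = 1024                                         # assumed training step count
diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(base_steps, "ddim64"),  # keep 64 evenly strided steps
    betas=get_named_beta_schedule("cosine", base_steps),
    model_mean_type="epsilon",                            # network predicts the noise
    model_var_type="fixed_small",
    channel_scales=np.full(6, 2.0),                       # assumed xyz+rgb scaling
    channel_biases=np.zeros(6),
)
# diffusion.ddim_sample_loop_progressive(model, shape=(B, 6, N), ...) then yields
# progressively denoised samples, already passed through unscale_out_dict.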
    	
        spar3d/models/diffusion/sampler.py
    ADDED
    
@@ -0,0 +1,134 @@
# --------------------------------------------------------
# Adapted from: https://github.com/openai/point-e
# Licensed under the MIT License
# Copyright (c) 2022 OpenAI

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# --------------------------------------------------------

from typing import Dict, Iterator

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion


class PointCloudSampler:
    """
    A wrapper around a model that produces conditional sample tensors.
    """

    def __init__(
        self,
        model: nn.Module,
        diffusion: GaussianDiffusion,
        num_points: int,
        point_dim: int = 3,
        guidance_scale: float = 3.0,
        clip_denoised: bool = True,
        sigma_min: float = 1e-3,
        sigma_max: float = 120,
        s_churn: float = 3,
    ):
        self.model = model
        self.num_points = num_points
        self.point_dim = point_dim
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn

        self.diffusion = diffusion

    def sample_batch_progressive(
        self,
        batch_size: int,
        condition: torch.Tensor,
        noise=None,
        device=None,
        guidance_scale=None,
    ) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Generate samples progressively using classifier-free guidance.

        Args:
            batch_size: Number of samples to generate
            condition: Conditioning tensor
            noise: Optional initial noise tensor
            device: Device to run on
            guidance_scale: Optional override for guidance scale

        Returns:
            Iterator of dicts containing intermediate samples
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        sample_shape = (batch_size, self.point_dim, self.num_points)

        # Double the batch for classifier-free guidance
        if guidance_scale != 1 and guidance_scale != 0:
            condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
            if noise is not None:
                noise = torch.cat([noise, noise], dim=0)
        model_kwargs = {"condition": condition}

        internal_batch_size = batch_size
        if guidance_scale != 1 and guidance_scale != 0:
            model = self._uncond_guide_model(self.model, guidance_scale)
            internal_batch_size *= 2
        else:
            model = self.model

        samples_it = self.diffusion.ddim_sample_loop_progressive(
            model,
            shape=(internal_batch_size, *sample_shape[1:]),
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=self.clip_denoised,
            noise=noise,
        )

        for x in samples_it:
            samples = {
                "xstart": x["pred_xstart"][:batch_size],
                "xprev": x["sample"][:batch_size] if "sample" in x else x["x"],
            }
            yield samples

    def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
        """
        Wraps the model for classifier-free guidance.
        """

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)

            eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

| 134 | 
            +
                    return model_fn
         | 
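Note: the `_uncond_guide_model` wrapper above implements standard classifier-free guidance: the batch is doubled (conditional half plus zeroed-condition half) and the two noise predictions are blended as `uncond_eps + scale * (cond_eps - uncond_eps)`. A minimal standalone sketch of that blend with random tensors (illustrative shapes only, not the SPAR3D model itself):

import torch

scale = 3.0  # guidance_scale; 1.0 recovers the purely conditional prediction

# Pretend outputs of one doubled-batch forward pass: conditional and
# unconditional noise predictions for 2 samples, 3 channels, 8 points.
cond_eps = torch.randn(2, 3, 8)
uncond_eps = torch.randn(2, 3, 8)

half_eps = uncond_eps + scale * (cond_eps - uncond_eps)

# Sanity check: with scale = 1 the blend equals the conditional prediction.
assert torch.allclose(uncond_eps + 1.0 * (cond_eps - uncond_eps), cond_eps)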
    	
        spar3d/models/global_estimator/reni_estimator.py
    ADDED
    
@@ -0,0 +1,112 @@
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from spar3d.models.illumination.reni.env_map import RENIEnvMap
from spar3d.models.utils import BaseModule


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)"

    def proj_u2a(u, a):
        r"""
        u: batch x 3
        a: batch x 3
        """
        inner_prod = torch.sum(u * a, dim=-1, keepdim=True)
        norm2 = torch.sum(u**2, dim=-1, keepdim=True)
        norm2 = torch.clamp(norm2, min=1e-8)
        factor = inner_prod / (norm2 + 1e-10)
        return factor * u

    x_raw, y_raw = d6[..., :3], d6[..., 3:]

    x = F.normalize(x_raw, dim=-1)
    y = F.normalize(y_raw - proj_u2a(x, y_raw), dim=-1)
    z = torch.cross(x, y, dim=-1)

    return torch.stack((x, y, z), dim=-1)


class ReniLatentCodeEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        triplane_features: int = 40

        n_layers: int = 5
        hidden_features: int = 512
        activation: str = "relu"

        pool: str = "mean"

        reni_env_config: dict = field(default_factory=dict)

    cfg: Config

    def configure(self):
        layers = []
        cur_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))

            cur_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*layers)

        self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config)
        self.latent_dim = self.reni_env_map.field.latent_dim

        self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3)
        nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3)

        self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6)
        nn.init.constant_(self.fc_rotations.bias, 0.0)
        nn.init.normal_(
            self.fc_rotations.weight, mean=0.0, std=0.01
        )  # Small variance here

        self.fc_scale = nn.Linear(self.cfg.hidden_features, 1)
        nn.init.constant_(self.fc_scale.bias, 0.0)
        nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01)  # Small variance here

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )
        x = x.mean(dim=[-2, -1])

        latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
        rotations = self.fc_rotations(x)
        scale = self.fc_scale(x)

        env_map = self.reni_env_map(latents, rotation_6d_to_matrix(rotations), scale)

        return {"illumination": env_map["rgb"]}
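Note: the rotation head above predicts a continuous 6D rotation parameterization, which `rotation_6d_to_matrix` orthonormalizes via Gram-Schmidt. A quick standalone sanity-check sketch, assuming the function above is importable from spar3d.models.global_estimator.reni_estimator:

import torch

from spar3d.models.global_estimator.reni_estimator import rotation_6d_to_matrix

d6 = torch.randn(4, 6)         # batch of raw 6D rotation parameters
R = rotation_6d_to_matrix(d6)  # -> (4, 3, 3)

# Valid rotation matrices: orthonormal and with determinant +1.
eye = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(R.transpose(-1, -2) @ R, eye, atol=1e-5)
assert torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)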
    	
        spar3d/models/illumination/reni/components/film_siren.py
    ADDED
    
@@ -0,0 +1,148 @@
"""FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/."""

from typing import Optional

import numpy as np
import torch
from torch import nn


def kaiming_leaky_init(m):
    classname = m.__class__.__name__
    if classname.find("Linear") != -1:
        torch.nn.init.kaiming_normal_(
            m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu"
        )


def frequency_init(freq):
    def init(m):
        with torch.no_grad():
            if isinstance(m, nn.Linear):
                num_input = m.weight.size(-1)
                m.weight.uniform_(
                    -np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq
                )

    return init


def first_layer_film_sine_init(m):
    with torch.no_grad():
        if isinstance(m, nn.Linear):
            num_input = m.weight.size(-1)
            m.weight.uniform_(-1 / num_input, 1 / num_input)


class CustomMappingNetwork(nn.Module):
    def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
        super().__init__()

        self.network = []

        for _ in range(map_hidden_layers):
            self.network.append(nn.Linear(in_features, map_hidden_dim))
            self.network.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = map_hidden_dim

        self.network.append(nn.Linear(map_hidden_dim, map_output_dim))

        self.network = nn.Sequential(*self.network)

        self.network.apply(kaiming_leaky_init)
        with torch.no_grad():
            self.network[-1].weight *= 0.25

    def forward(self, z):
        frequencies_offsets = self.network(z)
        frequencies = frequencies_offsets[
            ..., : torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor")
        ]
        phase_shifts = frequencies_offsets[
            ..., torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") :
        ]

        return frequencies, phase_shifts


class FiLMLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        x = self.layer(x)
        freq = freq.expand_as(x)
        phase_shift = phase_shift.expand_as(x)
        return torch.sin(freq * x + phase_shift)


class FiLMSiren(nn.Module):
    """FiLM Conditioned Siren network."""

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        mapping_network_in_dim: int,
        mapping_network_layers: int,
        mapping_network_features: int,
        out_dim: int,
        outermost_linear: bool = False,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features
        self.mapping_network_in_dim = mapping_network_in_dim
        self.mapping_network_layers = mapping_network_layers
        self.mapping_network_features = mapping_network_features
        self.outermost_linear = outermost_linear
        self.out_activation = out_activation

        self.net = nn.ModuleList()

        self.net.append(FiLMLayer(self.in_dim, self.hidden_features))

        for _ in range(self.hidden_layers - 1):
            self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))

        self.final_layer = None
        if self.outermost_linear:
            self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
            self.final_layer.apply(frequency_init(25))
        else:
            final_layer = FiLMLayer(self.hidden_features, self.out_dim)
            self.net.append(final_layer)

        self.mapping_network = CustomMappingNetwork(
            in_features=self.mapping_network_in_dim,
            map_hidden_layers=self.mapping_network_layers,
            map_hidden_dim=self.mapping_network_features,
            map_output_dim=(len(self.net)) * self.hidden_features * 2,
        )

        self.net.apply(frequency_init(25))
        self.net[0].apply(first_layer_film_sine_init)

    def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
        """Apply conditional frequencies and phase shifts from the mapping network."""
        frequencies = frequencies * 15 + 30

        for index, layer in enumerate(self.net):
            start = index * self.hidden_features
            end = (index + 1) * self.hidden_features
            x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])

        x = self.final_layer(x) if self.final_layer is not None else x
        output = self.out_activation(x) if self.out_activation is not None else x
        return output

    def forward(self, x, conditioning_input):
        """Forward pass."""
        frequencies, phase_shifts = self.mapping_network(conditioning_input)
        return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)
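Note: a minimal usage sketch for the FiLM-conditioned SIREN above, with the FiLMSiren class in scope. The sizes are illustrative assumptions, not the configuration used elsewhere in SPAR3D:

import torch

net = FiLMSiren(
    in_dim=3,                   # e.g. unit direction vectors
    hidden_layers=4,
    hidden_features=128,
    mapping_network_in_dim=36,  # size of the conditioning latent (assumed)
    mapping_network_layers=2,
    mapping_network_features=128,
    out_dim=3,                  # e.g. RGB
    outermost_linear=True,
)

directions = torch.randn(8, 100, 3)  # 8 batches of 100 query directions
latents = torch.randn(8, 1, 36)      # one conditioning code per batch element
rgb = net(directions, latents)       # -> (8, 100, 3)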
    	
        spar3d/models/illumination/reni/components/siren.py
    ADDED
    
@@ -0,0 +1,118 @@
"""Siren MLP https://www.vincentsitzmann.com/siren/"""

from typing import Optional

import numpy as np
import torch
from torch import nn


class SineLayer(nn.Module):
    """
    Sine layer for the SIREN network.
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30.0
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else:
                self.linear.weight.uniform_(
                    -np.sqrt(6 / self.in_features) / self.omega_0,
                    np.sqrt(6 / self.in_features) / self.omega_0,
                )

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))


class Siren(nn.Module):
    """Siren network.

    Args:
        in_dim: Input layer dimension
        hidden_layers: Number of hidden layers
        hidden_features: Width of each hidden layer
        out_dim: Output layer dimension. Uses hidden_features if None.
        outermost_linear: Whether the output layer is a plain linear layer.
        first_omega_0: Frequency scaling of the first sine layer.
        hidden_omega_0: Frequency scaling of the hidden sine layers.
        out_activation: output activation function.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        out_dim: Optional[int] = None,
        outermost_linear: bool = False,
        first_omega_0: float = 30,
        hidden_omega_0: float = 30,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.outermost_linear = outermost_linear
        self.first_omega_0 = first_omega_0
        self.hidden_omega_0 = hidden_omega_0
        self.hidden_layers = hidden_layers
        self.layer_width = hidden_features
        self.out_activation = out_activation

        self.net = []
        self.net.append(
            SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
        )

        for _ in range(hidden_layers):
            self.net.append(
                SineLayer(
                    hidden_features,
                    hidden_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if outermost_linear:
            final_layer = nn.Linear(hidden_features, self.out_dim)

            with torch.no_grad():
                final_layer.weight.uniform_(
                    -np.sqrt(6 / hidden_features) / hidden_omega_0,
                    np.sqrt(6 / hidden_features) / hidden_omega_0,
                )

            self.net.append(final_layer)
        else:
            self.net.append(
                SineLayer(
                    hidden_features,
                    self.out_dim,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if self.out_activation is not None:
            self.net.append(self.out_activation)

        self.net = nn.Sequential(*self.net)

    def forward(self, model_input):
        """Forward pass through the network"""
        output = self.net(model_input)
        return output
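Note: for comparison with the FiLM variant, a minimal usage sketch of the unconditioned Siren above (illustrative sizes, with the class in scope):

import torch

siren = Siren(
    in_dim=2,           # e.g. 2D coordinates
    hidden_layers=3,
    hidden_features=256,
    out_dim=1,          # e.g. a scalar field value
    outermost_linear=True,
)

coords = torch.rand(1024, 2) * 2 - 1  # coordinates in [-1, 1]
values = siren(coords)                # -> (1024, 1)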
    	
        spar3d/models/illumination/reni/components/transformer_decoder.py
    ADDED
    
@@ -0,0 +1,189 @@
from typing import Optional

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Multi-Head Attention module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.scale = self.head_dim**-0.5

        self.query = nn.Linear(direction_input_dim, latent_dim)
        self.key = nn.Linear(conditioning_input_dim, latent_dim)
        self.value = nn.Linear(conditioning_input_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Multi-Head Attention module.

        Args:
            query (torch.Tensor): The directional input tensor.
            key (torch.Tensor): The conditioning input tensor for the keys.
            value (torch.Tensor): The conditioning input tensor for the values.

        Returns:
            torch.Tensor: The output tensor of the Multi-Head Attention module.
        """
        batch_size = query.size(0)

        Q = (
            self.query(query)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        K = (
            self.key(key)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        V = (
            self.value(value)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        attention = (
            torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
        )
        attention = torch.softmax(attention, dim=-1)

        out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )

        out = self.fc_out(out).squeeze(1)
        return out


class AttentionLayer(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Attention Layer module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        self.mha = MultiHeadAttention(
            direction_input_dim, conditioning_input_dim, latent_dim, num_heads
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Attention Layer module.

        Args:
            directional_input (torch.Tensor): The directional input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Attention Layer module.
        """
        attn_output = self.mha(
            directional_input, conditioning_input, conditioning_input
        )
        out1 = self.norm1(attn_output + directional_input)
        fc_output = self.fc(out1)
        out2 = self.norm2(fc_output + out1)
        return out2


class Decoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        conditioning_input_dim: int,
        hidden_features: int,
        num_heads: int,
        num_layers: int,
        out_activation: Optional[nn.Module],
    ):
        """
        Decoder module.

        Args:
            in_dim (int): The input dimension of the module.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            hidden_features (int): The number of hidden features in the module.
            num_heads (int): The number of heads to use in the attention mechanism.
            num_layers (int): The number of layers in the module.
            out_activation (nn.Module): The activation function to use on the output tensor.
        """
        super().__init__()
        self.residual_projection = nn.Linear(
            in_dim, hidden_features
        )  # projection for residual connection
        self.layers = nn.ModuleList(
            [
                AttentionLayer(
                    hidden_features, conditioning_input_dim, hidden_features, num_heads
                )
                for i in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_features, 3)  # 3 for RGB
        self.out_activation = out_activation

    def forward(
        self, x: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Decoder module.

        Args:
            x (torch.Tensor): The input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Decoder module.
        """
        x = self.residual_projection(x)
        for layer in self.layers:
            x = layer(x, conditioning_input)
        x = self.fc(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x
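Note: a minimal usage sketch of the cross-attention Decoder above, which projects query features, attends them to a set of conditioning tokens, and emits RGB. The sizes are illustrative assumptions (with the Decoder class in scope):

import torch
from torch import nn

decoder = Decoder(
    in_dim=3,                   # e.g. encoded query directions
    conditioning_input_dim=64,  # size of each conditioning token (assumed)
    hidden_features=128,
    num_heads=8,
    num_layers=2,
    out_activation=nn.Sigmoid(),
)

directions = torch.randn(4, 256, 3)      # 4 samples, 256 query directions each
conditioning = torch.randn(4, 16, 64)    # 16 conditioning tokens per sample
rgb = decoder(directions, conditioning)  # -> (4, 256, 3), values in (0, 1)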
    	
        spar3d/models/illumination/reni/components/vn_layers.py
    ADDED
    
@@ -0,0 +1,548 @@
# MIT License

# Copyright (c) 2022 Phil Wang

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""All code taken from https://github.com/lucidrains/VN-transformer"""

from collections import namedtuple
from functools import wraps

import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn

# constants

FlashAttentionConfig = namedtuple(
    "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
)

# helpers


def exists(val):
    return val is not None


def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


print_once = once(print)

# main class


class Attend(nn.Module):
    def __init__(self, dropout=0.0, flash=False, l2_dist=False):
        super().__init__()
        assert not (
            flash and l2_dist
        ), "flash attention is not compatible with l2 distance"
        self.l2_dist = l2_dist

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (
            flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once(
                "A100 GPU detected, using flash attention if input tensor is on cuda"
            )
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once(
                "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
            )
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v, mask=None):
        _, heads, q_len, _, _, is_cuda = (
            *q.shape,
            k.shape[-2],
            q.is_cuda,
        )

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        return out

    def forward(self, q, k, v, mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        scale = q.shape[-1] ** -0.5

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, "b j -> b 1 1 j")

        if self.flash:
            return self.flash_attn(q, k, v, mask=mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # l2 distance

        if self.l2_dist:
            # -cdist squared == (-q^2 + 2qk - k^2)
            # so simply work off the qk above
            q_squared = reduce(q**2, "b h i d -> b h i 1", "sum")
            k_squared = reduce(k**2, "b h j d -> b h 1 j", "sum")
            sim = sim * 2 - q_squared - k_squared

        # key padding mask

        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out


# helper


def exists(val):  # noqa: F811
    return val is not None


def default(val, d):
    return val if exists(val) else d


def inner_dot_product(x, y, *, dim=-1, keepdim=True):
    return (x * y).sum(dim=dim, keepdim=keepdim)


# layernorm


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


# equivariant modules


class VNLinear(nn.Module):
    def __init__(self, dim_in, dim_out, bias_epsilon=0.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))

        self.bias = None
        self.bias_epsilon = bias_epsilon

        # in this paper, they propose going for quasi-equivariance with a small bias, controllable with epsilon, which they claim leads to better stability and results

        if bias_epsilon > 0.0:
            self.bias = nn.Parameter(torch.randn(dim_out))

    def forward(self, x):
        out = einsum("... i c, o i -> ... o c", x, self.weight)

        if exists(self.bias):
            bias = F.normalize(self.bias, dim=-1) * self.bias_epsilon
            out = out + rearrange(bias, "... -> ... 1")

        return out


class VNReLU(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.W = nn.Parameter(torch.randn(dim, dim))
        self.U = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        q = einsum("... i c, o i -> ... o c", x, self.W)
        k = einsum("... i c, o i -> ... o c", x, self.U)

        qk = inner_dot_product(q, k)

        k_norm = k.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k

        out = torch.where(qk >= 0.0, q, q_projected_on_k)

        return out


class VNAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        dim_coor=3,
        bias_epsilon=0.0,
        l2_dist_attn=False,
        flash=False,
        num_latents=None,  # setting this would enable perceiver-like cross attention from latents to sequence, with the latents derived from VNWeightedPool
    ):
        super().__init__()
        assert not (
            l2_dist_attn and flash
        ), "l2 distance attention is not compatible with flash attention"

        self.scale = (dim_coor * dim_head) ** -0.5
        dim_inner = dim_head * heads
        self.heads = heads

        self.to_q_input = None
        if exists(num_latents):
            self.to_q_input = VNWeightedPool(
                dim, num_pooled_tokens=num_latents, squeeze_out_pooled_dim=False
            )

        self.to_q = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
        self.to_k = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
        self.to_v = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
        self.to_out = VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon)

        if l2_dist_attn and not exists(num_latents):
            # tied queries and keys for l2 distance attention, and not perceiver-like attention
            self.to_k = self.to_q

        self.attend = Attend(flash=flash, l2_dist=l2_dist_attn)

    def forward(self, x, mask=None):
        """
        einstein notation
        b - batch
        n - sequence
        h - heads
        d - feature dimension (channels)
        c - coordinate dimension (3 for 3d space)
        i - source sequence dimension
        j - target sequence dimension
        """

        c = x.shape[-1]

        if exists(self.to_q_input):
            q_input = self.to_q_input(x, mask=mask)
        else:
            q_input = x

        q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) c -> b h n (d c)", h=self.heads),
            (q, k, v),
        )

        out = self.attend(q, k, v, mask=mask)

        out = rearrange(out, "b h n (d c) -> b n (h d) c", c=c)
        return self.to_out(out)


def VNFeedForward(dim, mult=4, bias_epsilon=0.0):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon),
        VNReLU(dim_inner),
        VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon),
    )


class VNLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.ln = LayerNorm(dim)

    def forward(self, x):
        norms = x.norm(dim=-1)
        x = x / rearrange(norms.clamp(min=self.eps), "... -> ... 1")
        ln_out = self.ln(norms)
        return x * rearrange(ln_out, "... -> ... 1")


class VNWeightedPool(nn.Module):
    def __init__(
        self, dim, dim_out=None, num_pooled_tokens=1, squeeze_out_pooled_dim=True
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))
        self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim

    def forward(self, x, mask=None):
        if exists(mask):
            mask = rearrange(mask, "b n -> b n 1 1")
            x = x.masked_fill(~mask, 0.0)
            numer = reduce(x, "b n d c -> b d c", "sum")
            denom = mask.sum(dim=1)
            mean_pooled = numer / denom.clamp(min=1e-6)
        else:
            mean_pooled = reduce(x, "b n d c -> b d c", "mean")

        out = einsum("b d c, m d e -> b m e c", mean_pooled, self.weight)

        if not self.squeeze_out_pooled_dim:
            return out

        out = rearrange(out, "b 1 d c -> b d c")
        return out


# equivariant VN transformer encoder


class VNTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        dim_coor=3,
        ff_mult=4,
        final_norm=False,
        bias_epsilon=0.0,
        l2_dist_attn=False,
        flash_attn=False,
    ):
        super().__init__()
        self.dim = dim
        self.dim_coor = dim_coor

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        VNAttention(
                            dim=dim,
                            dim_head=dim_head,
                            heads=heads,
                            bias_epsilon=bias_epsilon,
                            l2_dist_attn=l2_dist_attn,
                            flash=flash_attn,
                        ),
                        VNLayerNorm(dim),
                        VNFeedForward(dim=dim, mult=ff_mult, bias_epsilon=bias_epsilon),
                        VNLayerNorm(dim),
                    ]
                )
            )

        self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()

    def forward(self, x, mask=None):
        *_, d, c = x.shape

        assert (
            x.ndim == 4 and d == self.dim and c == self.dim_coor
        ), f"input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))"

        for attn, attn_post_ln, ff, ff_post_ln in self.layers:
            x = attn_post_ln(attn(x, mask=mask)) + x
            x = ff_post_ln(ff(x)) + x

        return self.norm(x)


# invariant layers


class VNInvariant(nn.Module):
    def __init__(
        self,
        dim,
        dim_coor=3,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            VNLinear(dim, dim_coor), VNReLU(dim_coor), Rearrange("... d e -> ... e d")
        )

    def forward(self, x):
        return einsum("b n d i, b n i o -> b n o", x, self.mlp(x))


# main class


class VNTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens=None,
        dim_feat=None,
        dim_head=64,
        heads=8,
        dim_coor=3,
        reduce_dim_out=True,
        bias_epsilon=0.0,
        l2_dist_attn=False,
        flash_attn=False,
        translation_equivariance=False,
        translation_invariant=False,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        dim_feat = default(dim_feat, 0)
        self.dim_feat = dim_feat
        self.dim_coor_total = dim_coor + dim_feat

        assert (int(translation_equivariance) + int(translation_invariant)) <= 1
        self.translation_equivariance = translation_equivariance
        self.translation_invariant = translation_invariant

        self.vn_proj_in = nn.Sequential(
            Rearrange("... c -> ... 1 c"), VNLinear(1, dim, bias_epsilon=bias_epsilon)
        )

        self.encoder = VNTransformerEncoder(
            dim=dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            bias_epsilon=bias_epsilon,
            dim_coor=self.dim_coor_total,
            l2_dist_attn=l2_dist_attn,
            flash_attn=flash_attn,
        )

        if reduce_dim_out:
            self.vn_proj_out = nn.Sequential(
                VNLayerNorm(dim),
                VNLinear(dim, 1, bias_epsilon=bias_epsilon),
                Rearrange("... 1 c -> ... c"),
            )
        else:
            self.vn_proj_out = nn.Identity()

    def forward(
        self, coors, *, feats=None, mask=None, return_concatted_coors_and_feats=False
    ):
        if self.translation_equivariance or self.translation_invariant:
            coors_mean = reduce(coors, "... c -> c", "mean")
            coors = coors - coors_mean

        x = coors  # [batch, num_points, 3]

        if exists(feats):
            if feats.dtype == torch.long:
                assert exists(
                    self.token_emb
                ), "num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices"
                feats = self.token_emb(feats)

            assert (
                feats.shape[-1] == self.dim_feat
            ), f"dim_feat should be set to {feats.shape[-1]}"
            x = torch.cat((x, feats), dim=-1)  # [batch, num_points, 3 + dim_feat]

        assert x.shape[-1] == self.dim_coor_total

        x = self.vn_proj_in(x)  # [batch, num_points, hidden_dim, 3 + dim_feat]
        x = self.encoder(x, mask=mask)  # [batch, num_points, hidden_dim, 3 + dim_feat]
        x = self.vn_proj_out(x)  # [batch, num_points, 3 + dim_feat]

        coors_out, feats_out = (
            x[..., :3],
            x[..., 3:],
        )  # [batch, num_points, 3], [batch, num_points, dim_feat]

        if self.translation_equivariance:
            coors_out = coors_out + coors_mean

        if not exists(feats):
            return coors_out

        if return_concatted_coors_and_feats:
            return torch.cat((coors_out, feats_out), dim=-1)

        return coors_out, feats_out
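Note (not part of this commit): the vector-neuron layers above are built around rotation equivariance, i.e. applying an orthogonal transform to the coordinate dimension of the input should transform the output in the same way. A minimal standalone smoke test, assuming the spar3d package from this Space is importable and torch is installed, could look like this:

# Hypothetical equivariance check, not included in the repo.
import torch
from spar3d.models.illumination.reni.components.vn_layers import (
    VNLinear,
    VNTransformerEncoder,
)

torch.manual_seed(0)
R, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal 3x3 matrix

x = torch.randn(2, 16, 32, 3)  # (batch, seq, feature dim, coordinate dim)

linear = VNLinear(32, 64)  # bias_epsilon=0.0 keeps the layer exactly equivariant
print((linear(x @ R) - linear(x) @ R).abs().max())  # tiny, float rounding only

encoder = VNTransformerEncoder(dim=32, depth=1, dim_head=8, heads=2)
print((encoder(x @ R) - encoder(x) @ R).abs().max())  # small numerical error only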
    	
spar3d/models/illumination/reni/env_map.py
ADDED
@@ -0,0 +1,93 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.utils import BaseModule

from .field import RENIField


def _direction_from_coordinate(
    coordinate: Float[Tensor, "*B 2"],
) -> Float[Tensor, "*B 3"]:
    # OpenGL Convention
    # +X Right
    # +Y Up
    # +Z Backward

    u, v = coordinate.unbind(-1)
    theta = (2 * torch.pi * u) - torch.pi
    phi = torch.pi * v

    dir = torch.stack(
        [
            theta.sin() * phi.sin(),
            phi.cos(),
            -1 * theta.cos() * phi.sin(),
        ],
        -1,
    )
    return dir


def _get_sample_coordinates(
    resolution: List[int], device: Optional[torch.device] = None
) -> Float[Tensor, "H W 2"]:
    return torch.stack(
        torch.meshgrid(
            (torch.arange(resolution[1], device=device) + 0.5) / resolution[1],
            (torch.arange(resolution[0], device=device) + 0.5) / resolution[0],
            indexing="xy",
        ),
        -1,
    )


class RENIEnvMap(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        reni_config: dict = field(default_factory=dict)
        resolution: int = 128

    cfg: Config

    def configure(self):
        self.field = RENIField(self.cfg.reni_config)
        resolution = (self.cfg.resolution, self.cfg.resolution * 2)
        sample_directions = _direction_from_coordinate(
            _get_sample_coordinates(resolution)
        )
        self.img_shape = sample_directions.shape[:-1]

        sample_directions_flat = sample_directions.view(-1, 3)
        # Lastly these have y up but reni expects z up. Rotate 90 degrees on x axis
        sample_directions_flat = torch.stack(
            [
                sample_directions_flat[:, 0],
                -sample_directions_flat[:, 2],
                sample_directions_flat[:, 1],
            ],
            -1,
        )
        self.sample_directions = torch.nn.Parameter(
            sample_directions_flat, requires_grad=False
        )

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
        return {
            k: v.view(latent_codes.shape[0], *self.img_shape, -1)
            for k, v in self.field(
                self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1),
                latent_codes,
                rotation=rotation,
                scale=scale,
            ).items()
        }
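Note (not part of this commit): the helpers above build an equirectangular grid of pixel-centre coordinates and convert them to unit view directions, which RENIEnvMap then feeds through the RENI field. A small standalone sketch of that grid construction (it uses the module's underscore-prefixed helpers, so treat it as illustrative only):

# Hypothetical usage sketch, not included in the repo.
import torch
from spar3d.models.illumination.reni.env_map import (
    _direction_from_coordinate,
    _get_sample_coordinates,
)

H, W = 4, 8  # RENIEnvMap uses (resolution, resolution * 2)
coords = _get_sample_coordinates([H, W])   # (H, W, 2) pixel-centre u/v in (0, 1)
dirs = _direction_from_coordinate(coords)  # (H, W, 3), OpenGL convention (+Y up)

print(coords.shape, dirs.shape)            # torch.Size([4, 8, 2]) torch.Size([4, 8, 3])
print(torch.allclose(dirs.norm(dim=-1), torch.ones(H, W)))  # True: unit directions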
    	
spar3d/models/illumination/reni/field.py
ADDED
@@ -0,0 +1,736 @@
# Copyright 2023 The University of York. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified by Mark Boss

"""RENI field"""

import contextlib
from dataclasses import dataclass
from typing import Dict, Literal, Optional

import torch
from einops.layers.torch import Rearrange
from jaxtyping import Float
from torch import Tensor, nn

from spar3d.models.network import get_activation_module, trunc_exp
from spar3d.models.utils import BaseModule

from .components.film_siren import FiLMSiren
from .components.siren import Siren
from .components.transformer_decoder import Decoder
from .components.vn_layers import VNInvariant, VNLinear

# from nerfstudio.cameras.rays import RaySamples


def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
    """Computes the expected value of sin(y) where y ~ N(x_means, x_vars)

    Args:
        x_means: Mean values.
        x_vars: Variance of values.

    Returns:
        torch.Tensor: The expected value of sin.
    """

    return torch.exp(-0.5 * x_vars) * torch.sin(x_means)


class NeRFEncoding(torch.nn.Module):
    """Multi-scale sinusoidal encodings. Supports ``integrated positional encodings`` if covariances are provided.
    Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
        off_axis: bool = False,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

        self.off_axis = off_axis

        self.P = torch.tensor(
            [
                [0.8506508, 0, 0.5257311],
                [0.809017, 0.5, 0.309017],
                [0.5257311, 0.8506508, 0],
                [1, 0, 0],
                [0.809017, 0.5, -0.309017],
                [0.8506508, 0, -0.5257311],
                [0.309017, 0.809017, -0.5],
                [0, 0.5257311, -0.8506508],
                [0.5, 0.309017, -0.809017],
                [0, 1, 0],
                [-0.5257311, 0.8506508, 0],
                [-0.309017, 0.809017, -0.5],
                [0, 0.5257311, 0.8506508],
                [-0.309017, 0.809017, 0.5],
                [0.309017, 0.809017, 0.5],
                [0.5, 0.309017, 0.809017],
                [0.5, -0.309017, 0.809017],
                [0, 0, 1],
                [-0.5, 0.309017, 0.809017],
                [-0.809017, 0.5, 0.309017],
                [-0.809017, 0.5, -0.309017],
            ]
        ).T
| 110 | 
            +
                def get_out_dim(self) -> int:
         | 
| 111 | 
            +
                    if self.in_dim is None:
         | 
| 112 | 
            +
                        raise ValueError("Input dimension has not been set")
         | 
| 113 | 
            +
                    out_dim = self.in_dim * self.num_frequencies * 2
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    if self.off_axis:
         | 
| 116 | 
            +
                        out_dim = self.P.shape[1] * self.num_frequencies * 2
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    if self.include_input:
         | 
| 119 | 
            +
                        out_dim += self.in_dim
         | 
| 120 | 
            +
                    return out_dim
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def forward(
         | 
| 123 | 
            +
                    self,
         | 
| 124 | 
            +
                    in_tensor: Float[Tensor, "*b input_dim"],
         | 
| 125 | 
            +
                    covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None,
         | 
| 126 | 
            +
                ) -> Float[Tensor, "*b output_dim"]:
         | 
| 127 | 
            +
                    """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
         | 
| 128 | 
            +
                        in mip-NeRF.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    Args:
         | 
| 131 | 
            +
                        in_tensor: For best performance, the input tensor should be between 0 and 1.
         | 
| 132 | 
            +
                        covs: Covariances of input points.
         | 
| 133 | 
            +
                    Returns:
         | 
| 134 | 
            +
                        Output values will be between -1 and 1
         | 
| 135 | 
            +
                    """
         | 
| 136 | 
            +
                    # TODO check scaling here but just comment it for now
         | 
| 137 | 
            +
                    # in_tensor = 2 * torch.pi * in_tensor  # scale to [0, 2pi]
         | 
| 138 | 
            +
                    freqs = 2 ** torch.linspace(
         | 
| 139 | 
            +
                        self.min_freq, self.max_freq, self.num_frequencies
         | 
| 140 | 
            +
                    ).to(in_tensor.device)
         | 
| 141 | 
            +
                    # freqs = 2 ** (
         | 
| 142 | 
            +
                    #    torch.sin(torch.linspace(self.min_freq, torch.pi / 2.0, self.num_frequencies)) * self.max_freq
         | 
| 143 | 
            +
                    # ).to(in_tensor.device)
         | 
| 144 | 
            +
                    # freqs = 2 ** (
         | 
| 145 | 
            +
                    #     torch.linspace(self.min_freq, 1.0, self.num_frequencies).to(in_tensor.device) ** 0.2 * self.max_freq
         | 
| 146 | 
            +
                    # )
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    if self.off_axis:
         | 
| 149 | 
            +
                        scaled_inputs = (
         | 
| 150 | 
            +
                            torch.matmul(in_tensor, self.P.to(in_tensor.device))[..., None] * freqs
         | 
| 151 | 
            +
                        )
         | 
| 152 | 
            +
                    else:
         | 
| 153 | 
            +
                        scaled_inputs = (
         | 
| 154 | 
            +
                            in_tensor[..., None] * freqs
         | 
| 155 | 
            +
                        )  # [..., "input_dim", "num_scales"]
         | 
| 156 | 
            +
                    scaled_inputs = scaled_inputs.view(
         | 
| 157 | 
            +
                        *scaled_inputs.shape[:-2], -1
         | 
| 158 | 
            +
                    )  # [..., "input_dim" * "num_scales"]
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    if covs is None:
         | 
| 161 | 
            +
                        encoded_inputs = torch.sin(
         | 
| 162 | 
            +
                            torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)
         | 
| 163 | 
            +
                        )
         | 
| 164 | 
            +
                    else:
         | 
| 165 | 
            +
                        input_var = (
         | 
| 166 | 
            +
                            torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None]
         | 
| 167 | 
            +
                            * freqs[None, :] ** 2
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                        input_var = input_var.reshape((*input_var.shape[:-2], -1))
         | 
| 170 | 
            +
                        encoded_inputs = expected_sin(
         | 
| 171 | 
            +
                            torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1),
         | 
| 172 | 
            +
                            torch.cat(2 * [input_var], dim=-1),
         | 
| 173 | 
            +
                        )
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    if self.include_input:
         | 
| 176 | 
            +
                        encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1)
         | 
| 177 | 
            +
                    return encoded_inputs
         | 
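A minimal shape check for the encoding above (a sketch only; it assumes the same constructor arguments used by `create_nerf_encoding` further down in this file):

```python
import torch

# Encode 3D unit directions with 2 frequencies per axis and the raw input appended.
enc = NeRFEncoding(
    in_dim=3, num_frequencies=2, min_freq_exp=0.0, max_freq_exp=2.0, include_input=True
)
dirs = torch.nn.functional.normalize(torch.randn(4, 1024, 3), dim=-1)  # [B, num_rays, 3]
out = enc(dirs)

# 3 axes * 2 frequencies * (sin, cos) + 3 appended inputs = 15 features per ray.
assert enc.get_out_dim() == 15
assert out.shape == (4, 1024, 15)
```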
| 178 | 
            +
             | 
| 179 | 
            +
             | 
| 180 | 
            +
            class RENIField(BaseModule):
         | 
| 181 | 
            +
                @dataclass
         | 
| 182 | 
            +
                class Config(BaseModule.Config):
         | 
| 183 | 
            +
                    """Configuration for model instantiation"""
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    fixed_decoder: bool = False
         | 
| 186 | 
            +
                    """Whether to fix the decoder weights"""
         | 
| 187 | 
            +
                    equivariance: str = "SO2"
         | 
| 188 | 
            +
                    """Type of equivariance to use: None, SO2, SO3"""
         | 
| 189 | 
            +
                    axis_of_invariance: str = "y"
         | 
| 190 | 
            +
                    """Which axis should SO2 equivariance be invariant to: x, y, z"""
         | 
| 191 | 
            +
                    invariant_function: str = "GramMatrix"
         | 
| 192 | 
            +
                    """Type of invariant function to use: GramMatrix, VN"""
         | 
| 193 | 
            +
                    conditioning: str = "Concat"
         | 
| 194 | 
            +
                    """Type of conditioning to use: FiLM, Concat, Attention"""
         | 
| 195 | 
            +
                    positional_encoding: str = "NeRF"
         | 
| 196 | 
            +
                    """Type of positional encoding to use. Currently only NeRF is supported"""
         | 
| 197 | 
            +
                    encoded_input: str = "Directions"
         | 
| 198 | 
            +
                    """Type of input to encode: None, Directions, Conditioning, Both"""
         | 
| 199 | 
            +
                    latent_dim: int = 36
         | 
| 200 | 
            +
                    """Dimensionality of latent code, N for a latent code size of (N x 3)"""
         | 
| 201 | 
            +
                    hidden_layers: int = 3
         | 
| 202 | 
            +
                    """Number of hidden layers"""
         | 
| 203 | 
            +
                    hidden_features: int = 128
         | 
| 204 | 
            +
                    """Number of hidden features"""
         | 
| 205 | 
            +
                    mapping_layers: int = 3
         | 
| 206 | 
            +
                    """Number of mapping layers"""
         | 
| 207 | 
            +
                    mapping_features: int = 128
         | 
| 208 | 
            +
                    """Number of mapping features"""
         | 
| 209 | 
            +
                    num_attention_heads: int = 8
         | 
| 210 | 
            +
                    """Number of attention heads"""
         | 
| 211 | 
            +
                    num_attention_layers: int = 3
         | 
| 212 | 
            +
                    """Number of attention layers"""
         | 
| 213 | 
            +
                    out_features: int = 3  # RGB
         | 
| 214 | 
            +
                    """Number of output features"""
         | 
| 215 | 
            +
                    last_layer_linear: bool = False
         | 
| 216 | 
            +
                    """Whether to use a linear layer as the last layer"""
         | 
| 217 | 
            +
                    output_activation: str = "exp"
         | 
| 218 | 
            +
                    """Activation function for output layer: sigmoid, tanh, relu, exp, None"""
         | 
| 219 | 
            +
                    first_omega_0: float = 30.0
         | 
| 220 | 
            +
                    """Omega_0 for first layer"""
         | 
| 221 | 
            +
                    hidden_omega_0: float = 30.0
         | 
| 222 | 
            +
                    """Omega_0 for hidden layers"""
         | 
| 223 | 
            +
                    fixed_decoder: bool = False
         | 
| 224 | 
            +
                    """Whether to fix the decoder weights"""
         | 
| 225 | 
            +
                    old_implementation: bool = False
         | 
| 226 | 
            +
                    """Whether to match implementation of old RENI, when using old checkpoints"""
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                cfg: Config
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def configure(self):
         | 
| 231 | 
            +
                    self.equivariance = self.cfg.equivariance
         | 
| 232 | 
            +
                    self.conditioning = self.cfg.conditioning
         | 
| 233 | 
            +
                    self.latent_dim = self.cfg.latent_dim
         | 
| 234 | 
            +
                    self.hidden_layers = self.cfg.hidden_layers
         | 
| 235 | 
            +
                    self.hidden_features = self.cfg.hidden_features
         | 
| 236 | 
            +
                    self.mapping_layers = self.cfg.mapping_layers
         | 
| 237 | 
            +
                    self.mapping_features = self.cfg.mapping_features
         | 
| 238 | 
            +
                    self.out_features = self.cfg.out_features
         | 
| 239 | 
            +
                    self.last_layer_linear = self.cfg.last_layer_linear
         | 
| 240 | 
            +
                    self.output_activation = self.cfg.output_activation
         | 
| 241 | 
            +
                    self.first_omega_0 = self.cfg.first_omega_0
         | 
| 242 | 
            +
                    self.hidden_omega_0 = self.cfg.hidden_omega_0
         | 
| 243 | 
            +
                    self.old_implementation = self.cfg.old_implementation
         | 
| 244 | 
            +
                    self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    self.fixed_decoder = self.cfg.fixed_decoder
         | 
| 247 | 
            +
                    if self.cfg.invariant_function == "GramMatrix":
         | 
| 248 | 
            +
                        self.invariant_function = self.gram_matrix_invariance
         | 
| 249 | 
            +
                    else:
         | 
| 250 | 
            +
                        self.vn_proj_in = nn.Sequential(
         | 
| 251 | 
            +
                            Rearrange("... c -> ... 1 c"),
         | 
| 252 | 
            +
                            VNLinear(dim_in=1, dim_out=1, bias_epsilon=0),
         | 
| 253 | 
            +
                        )
         | 
| 254 | 
            +
                        dim_coor = 2 if self.cfg.equivariance == "SO2" else 3
         | 
| 255 | 
            +
                        self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor)
         | 
| 256 | 
            +
                        self.invariant_function = self.vn_invariance
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    self.network = self.setup_network()
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    if self.fixed_decoder:
         | 
| 261 | 
            +
                        for param in self.network.parameters():
         | 
| 262 | 
            +
                            param.requires_grad = False
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                        if self.cfg.invariant_function == "VN":
         | 
| 265 | 
            +
                            for param in self.vn_proj_in.parameters():
         | 
| 266 | 
            +
                                param.requires_grad = False
         | 
| 267 | 
            +
                            for param in self.vn_invar.parameters():
         | 
| 268 | 
            +
                                param.requires_grad = False
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                @contextlib.contextmanager
         | 
| 271 | 
            +
                def hold_decoder_fixed(self):
         | 
| 272 | 
            +
                    """Context manager to fix the decoder weights
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    Example usage:
         | 
| 275 | 
            +
                    ```
         | 
| 276 | 
            +
                    with instance_of_RENIField.hold_decoder_fixed():
         | 
| 277 | 
            +
                        # do stuff
         | 
| 278 | 
            +
                    ```
         | 
| 279 | 
            +
                    """
         | 
| 280 | 
            +
                    prev_state_network = {
         | 
| 281 | 
            +
                        name: p.requires_grad for name, p in self.network.named_parameters()
         | 
| 282 | 
            +
                    }
         | 
| 283 | 
            +
                    for param in self.network.parameters():
         | 
| 284 | 
            +
                        param.requires_grad = False
         | 
| 285 | 
            +
                    if self.cfg.invariant_function == "VN":
         | 
| 286 | 
            +
                        prev_state_proj_in = {
         | 
| 287 | 
            +
                            k: p.requires_grad for k, p in self.vn_proj_in.named_parameters()
         | 
| 288 | 
            +
                        }
         | 
| 289 | 
            +
                        prev_state_invar = {
         | 
| 290 | 
            +
                            k: p.requires_grad for k, p in self.vn_invar.named_parameters()
         | 
| 291 | 
            +
                        }
         | 
| 292 | 
            +
                        for param in self.vn_proj_in.parameters():
         | 
| 293 | 
            +
                            param.requires_grad = False
         | 
| 294 | 
            +
                        for param in self.vn_invar.parameters():
         | 
| 295 | 
            +
                            param.requires_grad = False
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    prev_decoder_state = self.fixed_decoder
         | 
| 298 | 
            +
                    self.fixed_decoder = True
         | 
| 299 | 
            +
                    try:
         | 
| 300 | 
            +
                        yield
         | 
| 301 | 
            +
                    finally:
         | 
| 302 | 
            +
                        # Restore the previous requires_grad state
         | 
| 303 | 
            +
                        for name, param in self.network.named_parameters():
         | 
| 304 | 
            +
                            param.requires_grad = prev_state_network[name]
         | 
| 305 | 
            +
                        if self.cfg.invariant_function == "VN":
         | 
| 306 | 
            +
                            for name, param in self.vn_proj_in.named_parameters():
         | 
| 307 | 
            +
                                param.requires_grad_(prev_state_proj_in[name])
         | 
| 308 | 
            +
                            for name, param in self.vn_invar.named_parameters():
         | 
| 309 | 
            +
                                param.requires_grad_(prev_state_invar[name])
         | 
| 310 | 
            +
                        self.fixed_decoder = prev_decoder_state
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                def vn_invariance(
         | 
| 313 | 
            +
                    self,
         | 
| 314 | 
            +
                    Z: Float[Tensor, "B latent_dim 3"],
         | 
| 315 | 
            +
                    D: Float[Tensor, "B num_rays 3"],
         | 
| 316 | 
            +
                    equivariance: Literal["None", "SO2", "SO3"] = "SO2",
         | 
| 317 | 
            +
                    axis_of_invariance: int = 1,
         | 
| 318 | 
            +
                ):
         | 
| 319 | 
            +
                    """Generates a batched invariant representation from latent code Z and direction coordinates D.
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    Args:
         | 
| 322 | 
            +
                        Z: [B, latent_dim, 3] - Latent code.
         | 
| 323 | 
            +
                    D: [B, num_rays, 3] - Direction coordinates.
         | 
| 324 | 
            +
                        equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'.
         | 
| 325 | 
            +
                        axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                    Returns:
         | 
| 328 | 
            +
                        Tuple[Tensor, Tensor]: directional_input, conditioning_input
         | 
| 329 | 
            +
                    """
         | 
| 330 | 
            +
                    assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
         | 
| 331 | 
            +
                    other_axes = [i for i in range(3) if i != axis_of_invariance]
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                    B, latent_dim, _ = Z.shape
         | 
| 334 | 
            +
                    _, num_rays, _ = D.shape
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    if equivariance == "None":
         | 
| 337 | 
            +
                        # get inner product between latent code and direction coordinates
         | 
| 338 | 
            +
                        innerprod = torch.sum(
         | 
| 339 | 
            +
                            Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
         | 
| 340 | 
            +
                        )  # [B, num_rays, latent_dim]
         | 
| 341 | 
            +
                        z_input = (
         | 
| 342 | 
            +
                            Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
         | 
| 343 | 
            +
                        )  # [B, num_rays, latent_dim * 3]
         | 
| 344 | 
            +
                        return innerprod, z_input
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    if equivariance == "SO2":
         | 
| 347 | 
            +
                        z_other = torch.stack(
         | 
| 348 | 
            +
                            (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
         | 
| 349 | 
            +
                        )  # [B, latent_dim, 2]
         | 
| 350 | 
            +
                        d_other = torch.stack(
         | 
| 351 | 
            +
                            (D[..., other_axes[0]], D[..., other_axes[1]]), -1
         | 
| 352 | 
            +
                        ).unsqueeze(2)  # [B, num_rays, 1, 2]
         | 
| 353 | 
            +
                        d_other = d_other.expand(
         | 
| 354 | 
            +
                            B, num_rays, latent_dim, 2
         | 
| 355 | 
            +
                        )  # [B, num_rays, latent_dim, 2]
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                        z_other_emb = self.vn_proj_in(z_other)  # [B, latent_dim, 1, 2]
         | 
| 358 | 
            +
                        z_other_invar = self.vn_invar(z_other_emb)  # [B, latent_dim, 2]
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                        # Get invariant component of Z along the axis of invariance
         | 
| 361 | 
            +
                        z_invar = Z[..., axis_of_invariance].unsqueeze(-1)  # [B, latent_dim, 1]
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                        # Innerproduct between projection of Z and D on the plane orthogonal to the axis of invariance.
         | 
| 364 | 
            +
                        # This encodes the rotational information. This is rotation-equivariant to rotations of either Z
         | 
| 365 | 
            +
                        # or D and is invariant to rotations of both Z and D.
         | 
| 366 | 
            +
                        innerprod = (z_other.unsqueeze(1) * d_other).sum(
         | 
| 367 | 
            +
                            dim=-1
         | 
| 368 | 
            +
                        )  # [B, num_rays, latent_dim]
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                        # Compute norm along the axes orthogonal to the axis of invariance
         | 
| 371 | 
            +
                        d_other_norm = torch.sqrt(
         | 
| 372 | 
            +
                            D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
         | 
| 373 | 
            +
                    ).unsqueeze(-1)  # [B, num_rays, 1]
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                        # Get invariant component of D along the axis of invariance
         | 
| 376 | 
            +
                        d_invar = D[..., axis_of_invariance].unsqueeze(-1)  # [B, num_rays, 1]
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                        directional_input = torch.cat(
         | 
| 379 | 
            +
                            (innerprod, d_invar, d_other_norm), -1
         | 
| 380 | 
            +
                        )  # [B, num_rays, latent_dim + 2]
         | 
| 381 | 
            +
                        conditioning_input = (
         | 
| 382 | 
            +
                            torch.cat((z_other_invar, z_invar), dim=-1)
         | 
| 383 | 
            +
                            .flatten(1)
         | 
| 384 | 
            +
                            .unsqueeze(1)
         | 
| 385 | 
            +
                            .expand(B, num_rays, latent_dim * 3)
         | 
| 386 | 
            +
                        )  # [B, num_rays, latent_dim * 3]
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                        return directional_input, conditioning_input
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    if equivariance == "SO3":
         | 
| 391 | 
            +
                        z = self.vn_proj_in(Z)  # [B, latent_dim, 1, 3]
         | 
| 392 | 
            +
                        z_invar = self.vn_invar(z)  # [B, latent_dim, 3]
         | 
| 393 | 
            +
                        conditioning_input = (
         | 
| 394 | 
            +
                            z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, latent_dim)
         | 
| 395 | 
            +
                        )  # [B, num_rays, latent_dim * 3]
         | 
| 396 | 
            +
                        # D [B, num_rays, 3] -> [B, num_rays, 1, 3]
         | 
| 397 | 
            +
                        # Z [B, latent_dim, 3] -> [B, 1, latent_dim, 3]
         | 
| 398 | 
            +
                        innerprod = torch.sum(
         | 
| 399 | 
            +
                            Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
         | 
| 400 | 
            +
                        )  # [B, num_rays, latent_dim]
         | 
| 401 | 
            +
                        return innerprod, conditioning_input
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                def gram_matrix_invariance(
         | 
| 404 | 
            +
                    self,
         | 
| 405 | 
            +
                    Z: Float[Tensor, "B latent_dim 3"],
         | 
| 406 | 
            +
                    D: Float[Tensor, "B num_rays 3"],
         | 
| 407 | 
            +
                    equivariance: Literal["None", "SO2", "SO3"] = "SO2",
         | 
| 408 | 
            +
                    axis_of_invariance: int = 1,
         | 
| 409 | 
            +
                ):
         | 
| 410 | 
            +
                    """Generates an invariant representation from latent code Z and direction coordinates D.
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                    Args:
         | 
| 413 | 
            +
                        Z (torch.Tensor): Latent code (B x latent_dim x 3)
         | 
| 414 | 
            +
                        D (torch.Tensor): Direction coordinates (B x num_rays x 3)
         | 
| 415 | 
            +
                        equivariance (str): Type of equivariance to use. Options are 'none', 'SO2', and 'SO3'
         | 
| 416 | 
            +
                        axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
         | 
| 417 | 
            +
                            Default is 1 (y-axis).
         | 
| 418 | 
            +
                    Returns:
         | 
| 419 | 
            +
                        torch.Tensor: Invariant representation
         | 
| 420 | 
            +
                    """
         | 
| 421 | 
            +
                    assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
         | 
| 422 | 
            +
                    other_axes = [i for i in range(3) if i != axis_of_invariance]
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    B, latent_dim, _ = Z.shape
         | 
| 425 | 
            +
                    _, num_rays, _ = D.shape
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                    if equivariance == "None":
         | 
| 428 | 
            +
                        # get inner product between latent code and direction coordinates
         | 
| 429 | 
            +
                        innerprod = torch.sum(
         | 
| 430 | 
            +
                            Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
         | 
| 431 | 
            +
                        )  # [B, num_rays, latent_dim]
         | 
| 432 | 
            +
                        z_input = (
         | 
| 433 | 
            +
                            Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
         | 
| 434 | 
            +
                        )  # [B, num_rays, latent_dim * 3]
         | 
| 435 | 
            +
                        return innerprod, z_input
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    if equivariance == "SO2":
         | 
| 438 | 
            +
                        # Select components along axes orthogonal to the axis of invariance
         | 
| 439 | 
            +
                        z_other = torch.stack(
         | 
| 440 | 
            +
                            (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
         | 
| 441 | 
            +
                        )  # [B, latent_dim, 2]
         | 
| 442 | 
            +
                        d_other = torch.stack(
         | 
| 443 | 
            +
                            (D[..., other_axes[0]], D[..., other_axes[1]]), -1
         | 
| 444 | 
            +
                        ).unsqueeze(2)  # [B, num_rays, 1, 2]
         | 
| 445 | 
            +
                        d_other = d_other.expand(
         | 
| 446 | 
            +
                            B, num_rays, latent_dim, 2
         | 
| 447 | 
            +
                        )  # size becomes [B, num_rays, latent_dim, 2]
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                        # Invariant representation of Z: Gram matrix G = Z*Z' has size [B, latent_dim, latent_dim]
         | 
| 450 | 
            +
                        G = torch.bmm(z_other, torch.transpose(z_other, 1, 2))
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                        # Flatten G to be size B x latent_dim^2
         | 
| 453 | 
            +
                        z_other_invar = G.flatten(start_dim=1)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                        # Get invariant component of Z along the axis of invariance
         | 
| 456 | 
            +
                        z_invar = Z[..., axis_of_invariance]  # [B, latent_dim]
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                        # Innerprod has size [B, num_rays, latent_dim]
         | 
| 459 | 
            +
                        innerprod = (z_other.unsqueeze(1) * d_other).sum(
         | 
| 460 | 
            +
                            dim=-1
         | 
| 461 | 
            +
                        )  # [B, num_rays, latent_dim]
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                        # Compute norm along the axes orthogonal to the axis of invariance
         | 
| 464 | 
            +
                        d_other_norm = torch.sqrt(
         | 
| 465 | 
            +
                            D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
         | 
| 466 | 
            +
                        ).unsqueeze(-1)  # [B, num_rays, 1]
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                        # Get invariant component of D along the axis of invariance
         | 
| 469 | 
            +
                        d_invar = D[..., axis_of_invariance].unsqueeze(-1)  # [B, num_rays, 1]
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                        if not self.old_implementation:
         | 
| 472 | 
            +
                            directional_input = torch.cat(
         | 
| 473 | 
            +
                                (innerprod, d_invar, d_other_norm), -1
         | 
| 474 | 
            +
                            )  # [B, num_rays, latent_dim + 2]
         | 
| 475 | 
            +
                            conditioning_input = (
         | 
| 476 | 
            +
                                torch.cat((z_other_invar, z_invar), -1)
         | 
| 477 | 
            +
                                .unsqueeze(1)
         | 
| 478 | 
            +
                            .expand(B, num_rays, latent_dim**2 + latent_dim)
         | 
| 479 | 
            +
                            )  # [B, num_rays, latent_dim^2 + latent_dim]
         | 
| 480 | 
            +
                        else:
         | 
| 481 | 
            +
                            # this is matching the previous implementation of RENI, needed if using old checkpoints
         | 
| 482 | 
            +
                            z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1)
         | 
| 483 | 
            +
                            z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1)
         | 
| 484 | 
            +
                            return torch.cat(
         | 
| 485 | 
            +
                                (innerprod, z_other_invar, d_other_norm, z_invar, d_invar), 1
         | 
| 486 | 
            +
                            )
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                        return directional_input, conditioning_input
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    if equivariance == "SO3":
         | 
| 491 | 
            +
                        G = Z @ torch.transpose(Z, 1, 2)  # [B, latent_dim, latent_dim]
         | 
| 492 | 
            +
                        innerprod = torch.sum(
         | 
| 493 | 
            +
                            Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
         | 
| 494 | 
            +
                        )  # [B, num_rays, latent_dim]
         | 
| 495 | 
            +
                        z_invar = (
         | 
| 496 | 
            +
                            G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1)
         | 
| 497 | 
            +
                        )  # [B, num_rays, latent_dim^2]
         | 
| 498 | 
            +
                        return innerprod, z_invar
         | 
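Both SO2 branches above exploit the same fact: the inner products of the components orthogonal to the invariance axis, the component along that axis, and the orthogonal norm are all unchanged when the latent code and the directions are rotated together about that axis. A small self-contained check of this property (standalone tensors, assuming the default y axis of invariance):

```python
import math

import torch


def rot_y(theta: float) -> torch.Tensor:
    """Rotation matrix about the y axis."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


Z = torch.randn(36, 3)  # latent code [latent_dim, 3]
D = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)  # directions [num_rays, 3]
R = rot_y(0.7)


def xz_innerprod(Z, D):
    # Inner product of the x/z projections, as in the SO2 branches: [num_rays, latent_dim]
    return D[:, [0, 2]] @ Z[:, [0, 2]].T


# Rotating Z and D jointly about y leaves the invariant features unchanged.
assert torch.allclose(xz_innerprod(Z, D), xz_innerprod(Z @ R.T, D @ R.T), atol=1e-5)
# The y components themselves are untouched by a rotation about y.
assert torch.allclose(D[:, 1], (D @ R.T)[:, 1], atol=1e-6)
```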
| 499 | 
            +
             | 
| 500 | 
            +
                def setup_network(self):
         | 
| 501 | 
            +
                    """Sets up the network architecture"""
         | 
| 502 | 
            +
                    base_input_dims = {
         | 
| 503 | 
            +
                        "VN": {
         | 
| 504 | 
            +
                            "None": {
         | 
| 505 | 
            +
                                "direction": self.latent_dim,
         | 
| 506 | 
            +
                                "conditioning": self.latent_dim * 3,
         | 
| 507 | 
            +
                            },
         | 
| 508 | 
            +
                            "SO2": {
         | 
| 509 | 
            +
                                "direction": self.latent_dim + 2,
         | 
| 510 | 
            +
                                "conditioning": self.latent_dim * 3,
         | 
| 511 | 
            +
                            },
         | 
| 512 | 
            +
                            "SO3": {
         | 
| 513 | 
            +
                                "direction": self.latent_dim,
         | 
| 514 | 
            +
                                "conditioning": self.latent_dim * 3,
         | 
| 515 | 
            +
                            },
         | 
| 516 | 
            +
                        },
         | 
| 517 | 
            +
                        "GramMatrix": {
         | 
| 518 | 
            +
                            "None": {
         | 
| 519 | 
            +
                                "direction": self.latent_dim,
         | 
| 520 | 
            +
                                "conditioning": self.latent_dim * 3,
         | 
| 521 | 
            +
                            },
         | 
| 522 | 
            +
                            "SO2": {
         | 
| 523 | 
            +
                                "direction": self.latent_dim + 2,
         | 
| 524 | 
            +
                                "conditioning": self.latent_dim**2 + self.latent_dim,
         | 
| 525 | 
            +
                            },
         | 
| 526 | 
            +
                            "SO3": {
         | 
| 527 | 
            +
                                "direction": self.latent_dim,
         | 
| 528 | 
            +
                                "conditioning": self.latent_dim**2,
         | 
| 529 | 
            +
                            },
         | 
| 530 | 
            +
                        },
         | 
| 531 | 
            +
                    }
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                    # Extract the necessary input dimensions
         | 
| 534 | 
            +
                    input_types = ["direction", "conditioning"]
         | 
| 535 | 
            +
                    input_dims = {
         | 
| 536 | 
            +
                        key: base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][
         | 
| 537 | 
            +
                            key
         | 
| 538 | 
            +
                        ]
         | 
| 539 | 
            +
                        for key in input_types
         | 
| 540 | 
            +
                    }
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                    # Helper function to create NeRF encoding
         | 
| 543 | 
            +
                    def create_nerf_encoding(in_dim):
         | 
| 544 | 
            +
                        return NeRFEncoding(
         | 
| 545 | 
            +
                            in_dim=in_dim,
         | 
| 546 | 
            +
                            num_frequencies=2,
         | 
| 547 | 
            +
                            min_freq_exp=0.0,
         | 
| 548 | 
            +
                            max_freq_exp=2.0,
         | 
| 549 | 
            +
                            include_input=True,
         | 
| 550 | 
            +
                        )
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    # Dictionary-based encoding setup
         | 
| 553 | 
            +
                    encoding_setup = {
         | 
| 554 | 
            +
                        "None": [],
         | 
| 555 | 
            +
                        "Conditioning": ["conditioning"],
         | 
| 556 | 
            +
                        "Directions": ["direction"],
         | 
| 557 | 
            +
                        "Both": ["direction", "conditioning"],
         | 
| 558 | 
            +
                    }
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                    # Setting up the required encodings
         | 
| 561 | 
            +
                    for input_type in encoding_setup.get(self.cfg.encoded_input, []):
         | 
| 562 | 
            +
                        # create self.{input_type}_encoding and update input_dims
         | 
| 563 | 
            +
                        setattr(
         | 
| 564 | 
            +
                            self,
         | 
| 565 | 
            +
                            f"{input_type}_encoding",
         | 
| 566 | 
            +
                            create_nerf_encoding(input_dims[input_type]),
         | 
| 567 | 
            +
                        )
         | 
| 568 | 
            +
                        input_dims[input_type] = getattr(
         | 
| 569 | 
            +
                            self, f"{input_type}_encoding"
         | 
| 570 | 
            +
                        ).get_out_dim()
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                    output_activation = get_activation_module(self.cfg.output_activation)
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    network = None
         | 
| 575 | 
            +
                    if self.conditioning == "Concat":
         | 
| 576 | 
            +
                        network = Siren(
         | 
| 577 | 
            +
                            in_dim=input_dims["direction"] + input_dims["conditioning"],
         | 
| 578 | 
            +
                            hidden_layers=self.hidden_layers,
         | 
| 579 | 
            +
                            hidden_features=self.hidden_features,
         | 
| 580 | 
            +
                            out_dim=self.out_features,
         | 
| 581 | 
            +
                            outermost_linear=self.last_layer_linear,
         | 
| 582 | 
            +
                            first_omega_0=self.first_omega_0,
         | 
| 583 | 
            +
                            hidden_omega_0=self.hidden_omega_0,
         | 
| 584 | 
            +
                            out_activation=output_activation,
         | 
| 585 | 
            +
                        )
         | 
| 586 | 
            +
                    elif self.conditioning == "FiLM":
         | 
| 587 | 
            +
                        network = FiLMSiren(
         | 
| 588 | 
            +
                            in_dim=input_dims["direction"],
         | 
| 589 | 
            +
                            hidden_layers=self.hidden_layers,
         | 
| 590 | 
            +
                            hidden_features=self.hidden_features,
         | 
| 591 | 
            +
                            mapping_network_in_dim=input_dims["conditioning"],
         | 
| 592 | 
            +
                            mapping_network_layers=self.mapping_layers,
         | 
| 593 | 
            +
                            mapping_network_features=self.mapping_features,
         | 
| 594 | 
            +
                            out_dim=self.out_features,
         | 
| 595 | 
            +
                            outermost_linear=True,
         | 
| 596 | 
            +
                            out_activation=output_activation,
         | 
| 597 | 
            +
                        )
         | 
| 598 | 
            +
                    elif self.conditioning == "Attention":
         | 
| 599 | 
            +
                        # transformer where K and V come from the conditioning input and Q from the positionally encoded directional input
         | 
| 600 | 
            +
                        network = Decoder(
         | 
| 601 | 
            +
                            in_dim=input_dims["direction"],
         | 
| 602 | 
            +
                            conditioning_input_dim=input_dims["conditioning"],
         | 
| 603 | 
            +
                            hidden_features=self.cfg.hidden_features,
         | 
| 604 | 
            +
                            num_heads=self.cfg.num_attention_heads,
         | 
| 605 | 
            +
                            num_layers=self.cfg.num_attention_layers,
         | 
| 606 | 
            +
                            out_activation=output_activation,
         | 
| 607 | 
            +
                        )
         | 
| 608 | 
            +
                    assert network is not None, "unknown conditioning type"
         | 
| 609 | 
            +
                    return network
         | 
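To make the dimension bookkeeping above concrete, here is the arithmetic under the default `Config` (`latent_dim=36`, `invariant_function="GramMatrix"`, `equivariance="SO2"`, `encoded_input="Directions"`, `conditioning="Concat"`):

```python
latent_dim = 36
direction_dim = latent_dim + 2                  # innerprod + d_invar + d_other_norm = 38
conditioning_dim = latent_dim**2 + latent_dim   # flattened Gram matrix + invariant axis = 1332

# NeRF encoding with 2 frequencies and include_input=True: dim * 2 * 2 + dim.
encoded_direction_dim = direction_dim * 2 * 2 + direction_dim  # 190

# With "Concat" conditioning, the Siren consumes both inputs concatenated.
siren_in_dim = encoded_direction_dim + conditioning_dim  # 1522
```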
| 610 | 
            +
             | 
| 611 | 
            +
                def apply_positional_encoding(self, directional_input, conditioning_input):
         | 
| 612 | 
            +
                    # conditioning on just invariant directional input
         | 
| 613 | 
            +
                    if self.cfg.encoded_input == "Conditioning":
         | 
| 614 | 
            +
                        conditioning_input = self.conditioning_encoding(
         | 
| 615 | 
            +
                            conditioning_input
         | 
| 616 | 
            +
                        )  # [num_rays, embedding_dim]
         | 
| 617 | 
            +
                    elif self.cfg.encoded_input == "Directions":
         | 
| 618 | 
            +
                        directional_input = self.direction_encoding(
         | 
| 619 | 
            +
                            directional_input
         | 
| 620 | 
            +
                        )  # [num_rays, embedding_dim]
         | 
| 621 | 
            +
                    elif self.cfg.encoded_input == "Both":
         | 
| 622 | 
            +
                        directional_input = self.direction_encoding(directional_input)
         | 
| 623 | 
            +
                        conditioning_input = self.conditioning_encoding(conditioning_input)
         | 
| 624 | 
            +
             | 
| 625 | 
            +
                    return directional_input, conditioning_input
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                def get_outputs(
         | 
| 628 | 
            +
                    self,
         | 
| 629 | 
            +
                    rays_d: Float[Tensor, "batch num_rays 3"],  # type: ignore
         | 
| 630 | 
            +
                    latent_codes: Float[Tensor, "batch_size latent_dim 3"],  # type: ignore
         | 
| 631 | 
            +
                    rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None,  # type: ignore
         | 
| 632 | 
            +
                    scale: Optional[Float[Tensor, "batch_size"]] = None,  # type: ignore
         | 
| 633 | 
            +
                ) -> Dict[str, Tensor]:
         | 
| 634 | 
            +
                    """Returns the outputs of the field.
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                    Args:
         | 
| 637 | 
            +
                    rays_d: [batch_size, num_rays, 3]
         | 
| 638 | 
            +
                        latent_codes: [batch_size, latent_dim, 3]
         | 
| 639 | 
            +
                        rotation: [batch_size, 3, 3]
         | 
| 640 | 
            +
                        scale: [batch_size]
         | 
| 641 | 
            +
                    """
         | 
| 642 | 
            +
                    if rotation is not None:
         | 
| 643 | 
            +
                        if len(rotation.shape) == 3:  # [batch_size, 3, 3]
         | 
| 644 | 
            +
                        # Rotate the latent codes by the per-batch rotation: [batch_size, latent_dim, 3]
         | 
| 645 | 
            +
                            latent_codes = torch.einsum(
         | 
| 646 | 
            +
                                "bik,blk->bli",
         | 
| 647 | 
            +
                                rotation,
         | 
| 648 | 
            +
                                latent_codes,
         | 
| 649 | 
            +
                            )
         | 
| 650 | 
            +
                        else:
         | 
| 651 | 
            +
                            raise NotImplementedError(
         | 
| 652 | 
            +
                                "Unsupported rotation shape. Expected [batch_size, 3, 3]."
         | 
| 653 | 
            +
                            )
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                    B, num_rays, _ = rays_d.shape
         | 
| 656 | 
            +
                    _, latent_dim, _ = latent_codes.shape
         | 
| 657 | 
            +
             | 
| 658 | 
            +
                    if not self.old_implementation:
         | 
| 659 | 
            +
                        directional_input, conditioning_input = self.invariant_function(
         | 
| 660 | 
            +
                            latent_codes,
         | 
| 661 | 
            +
                            rays_d,
         | 
| 662 | 
            +
                            equivariance=self.equivariance,
         | 
| 663 | 
            +
                            axis_of_invariance=self.axis_of_invariance,
         | 
| 664 | 
            +
                        )  # [B, num_rays, 3]
         | 
| 665 | 
            +
             | 
| 666 | 
            +
                        if self.cfg.positional_encoding == "NeRF":
         | 
| 667 | 
            +
                            directional_input, conditioning_input = self.apply_positional_encoding(
         | 
| 668 | 
            +
                                directional_input, conditioning_input
         | 
| 669 | 
            +
                            )
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                        if self.conditioning == "Concat":
         | 
| 672 | 
            +
                            model_outputs = self.network(
         | 
| 673 | 
            +
                                torch.cat((directional_input, conditioning_input), dim=-1).reshape(
         | 
| 674 | 
            +
                                    B * num_rays, -1
         | 
| 675 | 
            +
                                )
         | 
| 676 | 
            +
                            ).view(B, num_rays, 3)  # returns -> [B num_rays, 3]
         | 
| 677 | 
            +
                        elif self.conditioning == "FiLM":
         | 
| 678 | 
            +
                            model_outputs = self.network(
         | 
| 679 | 
            +
                                directional_input.reshape(B * num_rays, -1),
         | 
| 680 | 
            +
                                conditioning_input.reshape(B * num_rays, -1),
         | 
| 681 | 
            +
                            ).view(B, num_rays, 3)  # returns -> [B num_rays, 3]
         | 
| 682 | 
            +
                        elif self.conditioning == "Attention":
         | 
| 683 | 
            +
                            model_outputs = self.network(
         | 
| 684 | 
            +
                                directional_input.reshape(B * num_rays, -1),
         | 
| 685 | 
            +
                                conditioning_input.reshape(B * num_rays, -1),
         | 
| 686 | 
            +
                            ).view(B, num_rays, 3)  # returns -> [B num_rays, 3]
         | 
| 687 | 
            +
                    else:
         | 
| 688 | 
            +
                    # in the old implementation, directions were sampled with y-up rather than z-up, so y and z are swapped here
         | 
| 689 | 
            +
                        directions = torch.stack(
         | 
| 690 | 
            +
                            (rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1
         | 
| 691 | 
            +
                        )
         | 
| 692 | 
            +
                        model_input = self.invariant_function(
         | 
| 693 | 
            +
                            latent_codes,
         | 
| 694 | 
            +
                            directions,
         | 
| 695 | 
            +
                            equivariance=self.equivariance,
         | 
| 696 | 
            +
                            axis_of_invariance=self.axis_of_invariance,
         | 
| 697 | 
            +
                        )  # [B, num_rays, 3]
         | 
| 698 | 
            +
             | 
| 699 | 
            +
                        model_outputs = self.network(model_input.view(B * num_rays, -1)).view(
         | 
| 700 | 
            +
                            B, num_rays, 3
         | 
| 701 | 
            +
                        )
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                    outputs = {}
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                    if scale is not None:
         | 
| 706 | 
            +
                    scale = trunc_exp(scale)  # [batch_size]; exp ensures the scale is positive
         | 
| 707 | 
            +
                    model_outputs = model_outputs * scale.view(-1, 1, 1)  # [B, num_rays, 3]
         | 
| 708 | 
            +
             | 
| 709 | 
            +
                    outputs["rgb"] = model_outputs
         | 
| 710 | 
            +
             | 
| 711 | 
            +
                    return outputs
         | 
| 712 | 
            +
             | 
| 713 | 
            +
                def forward(
         | 
| 714 | 
            +
                    self,
         | 
| 715 | 
            +
                    rays_d: Float[Tensor, "batch num_rays 3"],  # type: ignore
         | 
| 716 | 
            +
                    latent_codes: Float[Tensor, "batch_size latent_dim 3"],  # type: ignore
         | 
| 717 | 
            +
                    rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None,  # type: ignore
         | 
| 718 | 
            +
                    scale: Optional[Float[Tensor, "batch_size"]] = None,  # type: ignore
         | 
| 719 | 
            +
                ) -> Dict[str, Tensor]:
         | 
| 720 | 
            +
                    """Evaluates spherical field for a given ray bundle and rotation.
         | 
| 721 | 
            +
             | 
| 722 | 
            +
                    Args:
         | 
| 723 | 
            +
                    rays_d: [B, num_rays, 3]
         | 
| 724 | 
            +
                    latent_codes: [B, latent_dim, 3]
         | 
| 725 | 
            +
                        rotation: [batch_size, 3, 3]
         | 
| 726 | 
            +
                        scale: [batch_size]
         | 
| 727 | 
            +
             | 
| 728 | 
            +
                    Returns:
         | 
| 729 | 
            +
                        Dict[str, Tensor]: A dictionary containing the outputs of the field.
         | 
| 730 | 
            +
                    """
         | 
| 731 | 
            +
                    return self.get_outputs(
         | 
| 732 | 
            +
                        rays_d=rays_d,
         | 
| 733 | 
            +
                        latent_codes=latent_codes,
         | 
| 734 | 
            +
                        rotation=rotation,
         | 
| 735 | 
            +
                        scale=scale,
         | 
| 736 | 
            +
                    )
         | 
    	
        spar3d/models/image_estimator/clip_based_estimator.py
    ADDED
    
    | @@ -0,0 +1,184 @@ | |
| 1 | 
            +
            from dataclasses import dataclass, field
         | 
| 2 | 
            +
            from typing import Any, List, Optional
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import alpha_clip
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            from jaxtyping import Float
         | 
| 8 | 
            +
            from torch import Tensor
         | 
| 9 | 
            +
            from torchvision.transforms import Normalize
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from spar3d.models.network import get_activation
         | 
| 12 | 
            +
            from spar3d.models.utils import BaseModule
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
         | 
| 15 | 
            +
            OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            @dataclass
         | 
| 19 | 
            +
            class HeadSpec:
         | 
| 20 | 
            +
                name: str
         | 
| 21 | 
            +
                out_channels: int
         | 
| 22 | 
            +
                n_hidden_layers: int
         | 
| 23 | 
            +
                output_activation: Optional[str] = None
         | 
| 24 | 
            +
                output_bias: float = 0.0
         | 
| 25 | 
            +
                add_to_decoder_features: bool = False
         | 
| 26 | 
            +
                shape: Optional[list[int]] = None
         | 
| 27 | 
            +
                distribution_eval: str = "sample"
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class ClipBasedHeadEstimator(BaseModule):
         | 
| 31 | 
            +
                @dataclass
         | 
| 32 | 
            +
                class Config(BaseModule.Config):
         | 
| 33 | 
            +
                    model: str = "ViT-L/14@336px"
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    distribution: str = "beta"
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    # ["mean", "mode", "sample", "sample_mean"]
         | 
| 38 | 
            +
                    distribution_eval: str = "mode"
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    activation: str = "relu"
         | 
| 41 | 
            +
                    hidden_features: int = 512
         | 
| 42 | 
            +
                    heads: List[HeadSpec] = field(default_factory=lambda: [])
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                cfg: Config
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def configure(self):
         | 
| 47 | 
            +
                    self.model, _ = alpha_clip.load(
         | 
| 48 | 
            +
                        self.cfg.model,
         | 
| 49 | 
            +
                    )  # change to your own ckpt path
         | 
| 50 | 
            +
                    self.model.eval()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    if not hasattr(self.model.visual, "input_resolution"):
         | 
| 53 | 
            +
                        self.img_size = 224
         | 
| 54 | 
            +
                    else:
         | 
| 55 | 
            +
                        self.img_size = self.model.visual.input_resolution
         | 
| 56 | 
            +
                    # Check if img_size is subscriptable and pick the first element
         | 
| 57 | 
            +
                        if hasattr(self.img_size, "__getitem__"):
         | 
| 58 | 
            +
                            self.img_size = self.img_size[0]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # Do not add the weights in self.model to the optimizer
         | 
| 61 | 
            +
                    for param in self.model.parameters():
         | 
| 62 | 
            +
                        param.requires_grad = False
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    assert len(self.cfg.heads) > 0
         | 
| 65 | 
            +
                    heads = {}
         | 
| 66 | 
            +
                    for head in self.cfg.heads:
         | 
| 67 | 
            +
                        head_layers = []
         | 
| 68 | 
            +
                        in_feature = self.model.visual.output_dim
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                        for i in range(head.n_hidden_layers):
         | 
| 71 | 
            +
                            head_layers += [
         | 
| 72 | 
            +
                                nn.Linear(
         | 
| 73 | 
            +
                                    in_feature if i == 0 else self.cfg.hidden_features,
         | 
| 74 | 
            +
                                    self.cfg.hidden_features,
         | 
| 75 | 
            +
                                ),
         | 
| 76 | 
            +
                                self.make_activation(self.cfg.activation),
         | 
| 77 | 
            +
                            ]
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                        head_layers = [nn.Sequential(*head_layers)]
         | 
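                        # A shared MLP trunk is followed by two parallel branches; each
                        # branch outputs one scalar parameter of this head's distribution.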
| 80 | 
            +
                        head_layers += [
         | 
| 81 | 
            +
                            nn.Sequential(
         | 
| 82 | 
            +
                                nn.Linear(
         | 
| 83 | 
            +
                                    self.cfg.hidden_features,
         | 
| 84 | 
            +
                                    self.cfg.hidden_features,
         | 
| 85 | 
            +
                                ),
         | 
| 86 | 
            +
                                self.make_activation(self.cfg.activation),
         | 
| 87 | 
            +
                                nn.Linear(self.cfg.hidden_features, 1),
         | 
| 88 | 
            +
                            )
         | 
| 89 | 
            +
                            for _ in range(2)
         | 
| 90 | 
            +
                        ]
         | 
| 91 | 
            +
                        heads[head.name] = nn.ModuleList(head_layers)
         | 
| 92 | 
            +
                    self.heads = nn.ModuleDict(heads)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def make_activation(self, activation):
         | 
| 95 | 
            +
                    if activation == "relu":
         | 
| 96 | 
            +
                        return nn.ReLU(inplace=True)
         | 
| 97 | 
            +
                    elif activation == "silu":
         | 
| 98 | 
            +
                        return nn.SiLU(inplace=True)
         | 
| 99 | 
            +
                    else:
         | 
| 100 | 
            +
                        raise NotImplementedError
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def forward(
         | 
| 103 | 
            +
                    self,
         | 
| 104 | 
            +
                    cond_image: Float[Tensor, "B 1 H W 4"],
         | 
| 105 | 
            +
                    sample: bool = True,
         | 
| 106 | 
            +
                ) -> dict[str, Any]:
         | 
| 107 | 
            +
                    # Run the model
         | 
| 108 | 
            +
                    # Resize cond_image to 224
         | 
| 109 | 
            +
                    cond_image = cond_image.flatten(0, 1)
         | 
| 110 | 
            +
                    cond_image = nn.functional.interpolate(
         | 
| 111 | 
            +
                        cond_image.permute(0, 3, 1, 2),
         | 
| 112 | 
            +
                        size=(self.img_size, self.img_size),
         | 
| 113 | 
            +
                        mode="bilinear",
         | 
| 114 | 
            +
                        align_corners=False,
         | 
| 115 | 
            +
                    )
         | 
| 116 | 
            +
                    mask = cond_image[:, 3:4]
         | 
| 117 | 
            +
                    cond_image = cond_image[:, :3] * mask
         | 
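                    # RGB is normalized with the OpenAI CLIP statistics; the alpha mask is
                    # normalized separately and passed to the Alpha-CLIP visual encoder
                    # together with the image.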
| 118 | 
            +
                    cond_image = Normalize(
         | 
| 119 | 
            +
                        mean=OPENAI_DATASET_MEAN,
         | 
| 120 | 
            +
                        std=OPENAI_DATASET_STD,
         | 
| 121 | 
            +
                    )(cond_image)
         | 
| 122 | 
            +
                    mask = Normalize(0.5, 0.26)(mask).half()
         | 
| 123 | 
            +
                    image_features = self.model.visual(cond_image.half(), mask).float()
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # Run the heads
         | 
| 126 | 
            +
                    outputs = {}
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    for head_dict in self.cfg.heads:
         | 
| 129 | 
            +
                        head_name = head_dict.name
         | 
| 130 | 
            +
                        shared_head, d1_h, d2_h = self.heads[head_name]
         | 
| 131 | 
            +
                        shared_features = shared_head(image_features)
         | 
| 132 | 
            +
                        d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
         | 
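                        # d1 and d2 parameterize the predicted distribution: mean and
                        # (softplus-positive) scale for "normal", or the two concentration
                        # parameters for "beta".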
| 133 | 
            +
                        if self.cfg.distribution == "normal":
         | 
| 134 | 
            +
                            mean = d1
         | 
| 135 | 
            +
                            var = d2
         | 
| 136 | 
            +
                            if mean.shape[-1] == 1:
         | 
| 137 | 
            +
                                outputs[head_name] = torch.distributions.Normal(
         | 
| 138 | 
            +
                                    mean + head_dict.output_bias,
         | 
| 139 | 
            +
                                    torch.nn.functional.softplus(var),
         | 
| 140 | 
            +
                                )
         | 
| 141 | 
            +
                            else:
         | 
| 142 | 
            +
                                outputs[head_name] = torch.distributions.MultivariateNormal(
         | 
| 143 | 
            +
                                    mean + head_dict.output_bias,
         | 
| 144 | 
            +
                                    torch.nn.functional.softplus(var).diag_embed(),
         | 
| 145 | 
            +
                                )
         | 
| 146 | 
            +
                        elif self.cfg.distribution == "beta":
         | 
| 147 | 
            +
                            outputs[head_name] = torch.distributions.Beta(
         | 
| 148 | 
            +
                                torch.nn.functional.softplus(d1 + head_dict.output_bias),
         | 
| 149 | 
            +
                                torch.nn.functional.softplus(d2 + head_dict.output_bias),
         | 
| 150 | 
            +
                            )
         | 
| 151 | 
            +
                        else:
         | 
| 152 | 
            +
                            raise NotImplementedError
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    if sample:
         | 
| 155 | 
            +
                        for head_dict in self.cfg.heads:
         | 
| 156 | 
            +
                            head_name = head_dict.name
         | 
| 157 | 
            +
                            dist = outputs[head_name]
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                            if head_dict.distribution_eval == "mean":
         | 
| 160 | 
            +
                                out = dist.mean
         | 
| 161 | 
            +
                            elif head_dict.distribution_eval == "mode":
         | 
| 162 | 
            +
                                out = dist.mode
         | 
| 163 | 
            +
                            elif head_dict.distribution_eval == "sample_mean":
         | 
| 164 | 
            +
                                out = dist.sample([10]).mean(0)  # average over the drawn samples
         | 
| 165 | 
            +
                            else:
         | 
| 166 | 
            +
                                # use rsample if gradient is needed
         | 
| 167 | 
            +
                                out = dist.rsample() if self.training else dist.sample()
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                            outputs[head_name] = get_activation(head_dict.output_activation)(out)
         | 
| 170 | 
            +
                            outputs[f"{head_name}_dist"] = dist
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    for head in self.cfg.heads:
         | 
| 173 | 
            +
                        if head.shape:
         | 
| 174 | 
            +
                            if not sample:
         | 
| 175 | 
            +
                                raise ValueError(
         | 
| 176 | 
            +
                                    "Cannot reshape non-sampled probabilisitic outputs"
         | 
| 177 | 
            +
                                )
         | 
| 178 | 
            +
                            outputs[head.name] = outputs[head.name].reshape(*head.shape)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                        if head.add_to_decoder_features:
         | 
| 181 | 
            +
                            outputs[f"decoder_{head.name}"] = outputs[head.name]
         | 
| 182 | 
            +
                            del outputs[head.name]
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    return outputs
         | 
    	
        spar3d/models/isosurface.py
    ADDED
    
    | @@ -0,0 +1,229 @@ | |
| 1 | 
            +
            from typing import Optional, Tuple
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            from jaxtyping import Float, Integer
         | 
| 7 | 
            +
            from torch import Tensor
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .mesh import Mesh
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class IsosurfaceHelper(nn.Module):
         | 
| 13 | 
            +
                points_range: Tuple[float, float] = (0, 1)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                @property
         | 
| 16 | 
            +
                def grid_vertices(self) -> Float[Tensor, "N 3"]:
         | 
| 17 | 
            +
                    raise NotImplementedError
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                @property
         | 
| 20 | 
            +
                def requires_instance_per_batch(self) -> bool:
         | 
| 21 | 
            +
                    return False
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class MarchingTetrahedraHelper(IsosurfaceHelper):
         | 
| 25 | 
            +
                def __init__(self, resolution: int, tets_path: str):
         | 
| 26 | 
            +
                    super().__init__()
         | 
| 27 | 
            +
                    self.resolution = resolution
         | 
| 28 | 
            +
                    self.tets_path = tets_path
         | 
| 29 | 
            +
             | 
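                # Marching-tetrahedra lookup tables: `triangle_table` maps each of the
                # 16 inside/outside sign patterns of a tet's four vertices to up to two
                # triangles indexed over its six edges (-1 marks unused slots), and
                # `num_triangles_table` gives the triangle count for each pattern.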
| 30 | 
            +
                    self.triangle_table: Float[Tensor, "..."]
         | 
| 31 | 
            +
                    self.register_buffer(
         | 
| 32 | 
            +
                        "triangle_table",
         | 
| 33 | 
            +
                        torch.as_tensor(
         | 
| 34 | 
            +
                            [
         | 
| 35 | 
            +
                                [-1, -1, -1, -1, -1, -1],
         | 
| 36 | 
            +
                                [1, 0, 2, -1, -1, -1],
         | 
| 37 | 
            +
                                [4, 0, 3, -1, -1, -1],
         | 
| 38 | 
            +
                                [1, 4, 2, 1, 3, 4],
         | 
| 39 | 
            +
                                [3, 1, 5, -1, -1, -1],
         | 
| 40 | 
            +
                                [2, 3, 0, 2, 5, 3],
         | 
| 41 | 
            +
                                [1, 4, 0, 1, 5, 4],
         | 
| 42 | 
            +
                                [4, 2, 5, -1, -1, -1],
         | 
| 43 | 
            +
                                [4, 5, 2, -1, -1, -1],
         | 
| 44 | 
            +
                                [4, 1, 0, 4, 5, 1],
         | 
| 45 | 
            +
                                [3, 2, 0, 3, 5, 2],
         | 
| 46 | 
            +
                                [1, 3, 5, -1, -1, -1],
         | 
| 47 | 
            +
                                [4, 1, 2, 4, 3, 1],
         | 
| 48 | 
            +
                                [3, 0, 4, -1, -1, -1],
         | 
| 49 | 
            +
                                [2, 0, 1, -1, -1, -1],
         | 
| 50 | 
            +
                                [-1, -1, -1, -1, -1, -1],
         | 
| 51 | 
            +
                            ],
         | 
| 52 | 
            +
                            dtype=torch.long,
         | 
| 53 | 
            +
                        ),
         | 
| 54 | 
            +
                        persistent=False,
         | 
| 55 | 
            +
                    )
         | 
| 56 | 
            +
                    self.num_triangles_table: Integer[Tensor, "..."]
         | 
| 57 | 
            +
                    self.register_buffer(
         | 
| 58 | 
            +
                        "num_triangles_table",
         | 
| 59 | 
            +
                        torch.as_tensor(
         | 
| 60 | 
            +
                            [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
         | 
| 61 | 
            +
                        ),
         | 
| 62 | 
            +
                        persistent=False,
         | 
| 63 | 
            +
                    )
         | 
| 64 | 
            +
                    self.base_tet_edges: Integer[Tensor, "..."]
         | 
| 65 | 
            +
                    self.register_buffer(
         | 
| 66 | 
            +
                        "base_tet_edges",
         | 
| 67 | 
            +
                        torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
         | 
| 68 | 
            +
                        persistent=False,
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    tets = np.load(self.tets_path)
         | 
| 72 | 
            +
                    self._grid_vertices: Float[Tensor, "..."]
         | 
| 73 | 
            +
                    self.register_buffer(
         | 
| 74 | 
            +
                        "_grid_vertices",
         | 
| 75 | 
            +
                        torch.from_numpy(tets["vertices"]).float(),
         | 
| 76 | 
            +
                        persistent=False,
         | 
| 77 | 
            +
                    )
         | 
| 78 | 
            +
                    self.indices: Integer[Tensor, "..."]
         | 
| 79 | 
            +
                    self.register_buffer(
         | 
| 80 | 
            +
                        "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    center_indices, boundary_indices = self.get_center_boundary_index(
         | 
| 86 | 
            +
                        self._grid_vertices
         | 
| 87 | 
            +
                    )
         | 
| 88 | 
            +
                    self.center_indices: Integer[Tensor, "..."]
         | 
| 89 | 
            +
                    self.register_buffer("center_indices", center_indices, persistent=False)
         | 
| 90 | 
            +
                    self.boundary_indices: Integer[Tensor, "..."]
         | 
| 91 | 
            +
                    self.register_buffer("boundary_indices", boundary_indices, persistent=False)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def get_center_boundary_index(self, verts):
         | 
| 94 | 
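                # Return the index of the grid vertex closest to the origin and the
                # indices of all vertices lying on the bounding box of the tet grid.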
            +
                    magn = torch.sum(verts**2, dim=-1)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    center_idx = torch.argmin(magn)
         | 
| 97 | 
            +
                    boundary_neg = verts == verts.max()
         | 
| 98 | 
            +
                    boundary_pos = verts == verts.min()
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    boundary = torch.bitwise_or(boundary_pos, boundary_neg)
         | 
| 101 | 
            +
                    boundary = torch.sum(boundary.float(), dim=-1)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    boundary_idx = torch.nonzero(boundary)
         | 
| 104 | 
            +
                    return center_idx, boundary_idx.squeeze(dim=-1)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                def normalize_grid_deformation(
         | 
| 107 | 
            +
                    self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
         | 
| 108 | 
            +
                ) -> Float[Tensor, "Nv 3"]:
         | 
| 109 | 
            +
                    return (
         | 
| 110 | 
            +
                        (self.points_range[1] - self.points_range[0])
         | 
| 111 | 
            +
                        / self.resolution  # half tet size is approximately 1 / self.resolution
         | 
| 112 | 
            +
                        * torch.tanh(grid_vertex_offsets)
         | 
| 113 | 
            +
                    )  # FIXME: hard-coded activation
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                @property
         | 
| 116 | 
            +
                def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
         | 
| 117 | 
            +
                    return self._grid_vertices
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                @property
         | 
| 120 | 
            +
                def all_edges(self) -> Integer[Tensor, "Ne 2"]:
         | 
| 121 | 
            +
                    if self._all_edges is None:
         | 
| 122 | 
            +
                    # compute edges on the GPU, otherwise it would be VERY SLOW (mainly due to the unique operation)
         | 
| 123 | 
            +
                        edges = torch.tensor(
         | 
| 124 | 
            +
                            [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
         | 
| 125 | 
            +
                            dtype=torch.long,
         | 
| 126 | 
            +
                            device=self.indices.device,
         | 
| 127 | 
            +
                        )
         | 
| 128 | 
            +
                        _all_edges = self.indices[:, edges].reshape(-1, 2)
         | 
| 129 | 
            +
                        _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
         | 
| 130 | 
            +
                        _all_edges = torch.unique(_all_edges_sorted, dim=0)
         | 
| 131 | 
            +
                        self._all_edges = _all_edges
         | 
| 132 | 
            +
                    return self._all_edges
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def sort_edges(self, edges_ex2):
         | 
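                    # Put each edge's two vertex indices in a canonical order so that
                    # edges shared between tets map to the same key for `torch.unique`.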
| 135 | 
            +
                    with torch.no_grad():
         | 
| 136 | 
            +
                        order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
         | 
| 137 | 
            +
                        order = order.unsqueeze(dim=1)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        a = torch.gather(input=edges_ex2, index=order, dim=1)
         | 
| 140 | 
            +
                        b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    return torch.stack([a, b], -1)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                def _forward(self, pos_nx3, sdf_n, tet_fx4):
         | 
| 145 | 
            +
                    with torch.no_grad():
         | 
| 146 | 
            +
                        occ_n = sdf_n > 0
         | 
| 147 | 
            +
                        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
         | 
| 148 | 
            +
                        occ_sum = torch.sum(occ_fx4, -1)
         | 
| 149 | 
            +
                        valid_tets = (occ_sum > 0) & (occ_sum < 4)
         | 
| 150 | 
            +
                        occ_sum = occ_sum[valid_tets]
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                        # find all vertices
         | 
| 153 | 
            +
                        all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
         | 
| 154 | 
            +
                        all_edges = self.sort_edges(all_edges)
         | 
| 155 | 
            +
                        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                        unique_edges = unique_edges.long()
         | 
| 158 | 
            +
                        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
         | 
| 159 | 
            +
                        mapping = (
         | 
| 160 | 
            +
                            torch.ones(
         | 
| 161 | 
            +
                                (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
         | 
| 162 | 
            +
                            )
         | 
| 163 | 
            +
                            * -1
         | 
| 164 | 
            +
                        )
         | 
| 165 | 
            +
                        mapping[mask_edges] = torch.arange(
         | 
| 166 | 
            +
                            mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
         | 
| 167 | 
            +
                        )
         | 
| 168 | 
            +
                        idx_map = mapping[idx_map]  # map edges to verts
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                        interp_v = unique_edges[mask_edges]
         | 
| 171 | 
            +
                    edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
         | 
| 172 | 
            +
                    edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
         | 
| 173 | 
            +
                    edges_to_interp_sdf[:, -1] *= -1
         | 
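                    # Linear zero-crossing interpolation along each sign-changing edge:
                    # after the sign flip, verts = (s0 * p1 - s1 * p0) / (s0 - s1), the
                    # point where the interpolated SDF value is zero.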
| 174 | 
            +
             | 
| 175 | 
            +
                    denominator = edges_to_interp_sdf.sum(1, keepdim=True)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
         | 
| 178 | 
            +
                    verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    idx_map = idx_map.reshape(-1, 6)
         | 
| 181 | 
            +
             | 
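                    # Encode each valid tet's 4-bit occupancy pattern as an integer index
                    # into the triangle lookup tables.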
| 182 | 
            +
                    v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
         | 
| 183 | 
            +
                    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
         | 
| 184 | 
            +
                    num_triangles = self.num_triangles_table[tetindex]
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    # Generate triangle indices
         | 
| 187 | 
            +
                    faces = torch.cat(
         | 
| 188 | 
            +
                        (
         | 
| 189 | 
            +
                            torch.gather(
         | 
| 190 | 
            +
                                input=idx_map[num_triangles == 1],
         | 
| 191 | 
            +
                                dim=1,
         | 
| 192 | 
            +
                                index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
         | 
| 193 | 
            +
                            ).reshape(-1, 3),
         | 
| 194 | 
            +
                            torch.gather(
         | 
| 195 | 
            +
                                input=idx_map[num_triangles == 2],
         | 
| 196 | 
            +
                                dim=1,
         | 
| 197 | 
            +
                                index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
         | 
| 198 | 
            +
                            ).reshape(-1, 3),
         | 
| 199 | 
            +
                        ),
         | 
| 200 | 
            +
                        dim=0,
         | 
| 201 | 
            +
                    )
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    return verts, faces
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                def forward(
         | 
| 206 | 
            +
                    self,
         | 
| 207 | 
            +
                    level: Float[Tensor, "N3 1"],
         | 
| 208 | 
            +
                    deformation: Optional[Float[Tensor, "N3 3"]] = None,
         | 
| 209 | 
            +
                ) -> Mesh:
         | 
| 210 | 
            +
                    if deformation is not None:
         | 
| 211 | 
            +
                        grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
         | 
| 212 | 
            +
                            deformation
         | 
| 213 | 
            +
                        )
         | 
| 214 | 
            +
                    else:
         | 
| 215 | 
            +
                        grid_vertices = self.grid_vertices
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    mesh = Mesh(
         | 
| 220 | 
            +
                        v_pos=v_pos,
         | 
| 221 | 
            +
                        t_pos_idx=t_pos_idx,
         | 
| 222 | 
            +
                        # extras
         | 
| 223 | 
            +
                        grid_vertices=grid_vertices,
         | 
| 224 | 
            +
                        tet_edges=self.all_edges,
         | 
| 225 | 
            +
                        grid_level=level,
         | 
| 226 | 
            +
                        grid_deformation=deformation,
         | 
| 227 | 
            +
                    )
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    return mesh
         | 
    	
        spar3d/models/mesh.py
    ADDED
    
    | @@ -0,0 +1,317 @@ | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from typing import Any, Dict, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
            import trimesh
         | 
| 10 | 
            +
            from jaxtyping import Float, Integer
         | 
| 11 | 
            +
            from torch import Tensor
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from spar3d.models.utils import dot
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            try:
         | 
| 16 | 
            +
                from uv_unwrapper import Unwrapper
         | 
| 17 | 
            +
            except ImportError:
         | 
| 18 | 
            +
                import logging
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                logging.warning(
         | 
| 21 | 
            +
                    "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
         | 
| 22 | 
            +
                )
         | 
| 23 | 
            +
                # Exit early to avoid further errors
         | 
| 24 | 
            +
                raise ImportError("uv_unwrapper not found")
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            try:
         | 
| 27 | 
            +
                import gpytoolbox
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                TRIANGLE_REMESH_AVAILABLE = True
         | 
| 30 | 
            +
            except ImportError:
         | 
| 31 | 
            +
                TRIANGLE_REMESH_AVAILABLE = False
         | 
| 32 | 
            +
                import logging
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                logging.warning(
         | 
| 35 | 
            +
                    "Could not import gpytoolbox. Triangle remeshing functionality will be disabled. "
         | 
| 36 | 
            +
                    "Install via `pip install gpytoolbox`"
         | 
| 37 | 
            +
                )
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            try:
         | 
| 40 | 
            +
                import pynim
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                QUAD_REMESH_AVAILABLE = True
         | 
| 43 | 
            +
            except ImportError:
         | 
| 44 | 
            +
                QUAD_REMESH_AVAILABLE = False
         | 
| 45 | 
            +
                import logging
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                logging.warning(
         | 
| 48 | 
            +
                    "Could not import pynim. Quad remeshing functionality will be disabled. "
         | 
| 49 | 
            +
                    "Install via `pip install git+https://github.com/vork/[email protected]`"
         | 
| 50 | 
            +
                )
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            class Mesh:
         | 
| 54 | 
            +
                def __init__(
         | 
| 55 | 
            +
                    self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
         | 
| 56 | 
            +
                ) -> None:
         | 
| 57 | 
            +
                    self.v_pos: Float[Tensor, "Nv 3"] = v_pos
         | 
| 58 | 
            +
                    self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
         | 
| 59 | 
            +
                    self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
         | 
| 60 | 
            +
                    self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
         | 
| 61 | 
            +
                    self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
         | 
| 62 | 
            +
                    self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
         | 
| 63 | 
            +
                    self.extras: Dict[str, Any] = {}
         | 
| 64 | 
            +
                    for k, v in kwargs.items():
         | 
| 65 | 
            +
                        self.add_extra(k, v)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    self.unwrapper = Unwrapper()
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def add_extra(self, k, v) -> None:
         | 
| 70 | 
            +
                    self.extras[k] = v
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                @property
         | 
| 73 | 
            +
                def requires_grad(self):
         | 
| 74 | 
            +
                    return self.v_pos.requires_grad
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                @property
         | 
| 77 | 
            +
                def v_nrm(self):
         | 
| 78 | 
            +
                    if self._v_nrm is None:
         | 
| 79 | 
            +
                        self._v_nrm = self._compute_vertex_normal()
         | 
| 80 | 
            +
                    return self._v_nrm
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                @property
         | 
| 83 | 
            +
                def v_tng(self):
         | 
| 84 | 
            +
                    if self._v_tng is None:
         | 
| 85 | 
            +
                        self._v_tng = self._compute_vertex_tangent()
         | 
| 86 | 
            +
                    return self._v_tng
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                @property
         | 
| 89 | 
            +
                def v_tex(self):
         | 
| 90 | 
            +
                    if self._v_tex is None:
         | 
| 91 | 
            +
                        self.unwrap_uv()
         | 
| 92 | 
            +
                    return self._v_tex
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                @property
         | 
| 95 | 
            +
                def edges(self):
         | 
| 96 | 
            +
                    if self._edges is None:
         | 
| 97 | 
            +
                        self._edges = self._compute_edges()
         | 
| 98 | 
            +
                    return self._edges
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def _compute_vertex_normal(self):
         | 
| 101 | 
            +
                    i0 = self.t_pos_idx[:, 0]
         | 
| 102 | 
            +
                    i1 = self.t_pos_idx[:, 1]
         | 
| 103 | 
            +
                    i2 = self.t_pos_idx[:, 2]
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    v0 = self.v_pos[i0, :]
         | 
| 106 | 
            +
                    v1 = self.v_pos[i1, :]
         | 
| 107 | 
            +
                    v2 = self.v_pos[i2, :]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
         | 
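                    # The unnormalized cross product has magnitude proportional to the face
                    # area, so the scatter below produces area-weighted vertex normals.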
| 110 | 
            +
             | 
| 111 | 
            +
                    # Splat face normals to vertices
         | 
| 112 | 
            +
                    v_nrm = torch.zeros_like(self.v_pos)
         | 
| 113 | 
            +
                    v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
         | 
| 114 | 
            +
                    v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
         | 
| 115 | 
            +
                    v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                # Normalize, replace zero (degenerate) normals with some default value
         | 
| 118 | 
            +
                    v_nrm = torch.where(
         | 
| 119 | 
            +
                        dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
                    v_nrm = F.normalize(v_nrm, dim=1)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    if torch.is_anomaly_enabled():
         | 
| 124 | 
            +
                        assert torch.all(torch.isfinite(v_nrm))
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    return v_nrm
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def _compute_vertex_tangent(self):
         | 
| 129 | 
            +
                    vn_idx = [None] * 3
         | 
| 130 | 
            +
                    pos = [None] * 3
         | 
| 131 | 
            +
                    tex = [None] * 3
         | 
| 132 | 
            +
                    for i in range(0, 3):
         | 
| 133 | 
            +
                        pos[i] = self.v_pos[self.t_pos_idx[:, i]]
         | 
| 134 | 
            +
                        tex[i] = self.v_tex[self.t_pos_idx[:, i]]
         | 
| 135 | 
            +
                        # t_nrm_idx is always the same as t_pos_idx
         | 
| 136 | 
            +
                        vn_idx[i] = self.t_pos_idx[:, i]
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    tangents = torch.zeros_like(self.v_nrm)
         | 
| 139 | 
            +
                    tansum = torch.zeros_like(self.v_nrm)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    # Compute tangent space for each triangle
         | 
| 142 | 
            +
                    duv1 = tex[1] - tex[0]
         | 
| 143 | 
            +
                    duv2 = tex[2] - tex[0]
         | 
| 144 | 
            +
                    dpos1 = pos[1] - pos[0]
         | 
| 145 | 
            +
                    dpos2 = pos[2] - pos[0]
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                # Avoid division by zero for degenerate texture coordinates
         | 
| 152 | 
            +
                    denom_safe = denom.clip(1e-6)
         | 
| 153 | 
            +
                    tang = tng_nom / denom_safe
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # Update all 3 vertices
         | 
| 156 | 
            +
                    for i in range(0, 3):
         | 
| 157 | 
            +
                        idx = vn_idx[i][:, None].repeat(1, 3)
         | 
| 158 | 
            +
                        tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
         | 
| 159 | 
            +
                        tansum.scatter_add_(
         | 
| 160 | 
            +
                            0, idx, torch.ones_like(tang)
         | 
| 161 | 
            +
                        )  # tansum[n_i] = tansum[n_i] + 1
         | 
| 162 | 
            +
                    # Also normalize it. Here we do not normalize the individual triangles first so larger area
         | 
| 163 | 
            +
                    # triangles influence the tangent space more
         | 
| 164 | 
            +
                    tangents = tangents / tansum
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    # Normalize and make sure tangent is perpendicular to normal
         | 
| 167 | 
            +
                    tangents = F.normalize(tangents, dim=1)
         | 
| 168 | 
            +
                    tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    if torch.is_anomaly_enabled():
         | 
| 171 | 
            +
                        assert torch.all(torch.isfinite(tangents))
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    return tangents
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def quad_remesh(
         | 
| 176 | 
            +
                    self,
         | 
| 177 | 
            +
        quad_vertex_count: int = -1,
        quad_rosy: int = 4,
        quad_crease_angle: float = -1.0,
        quad_smooth_iter: int = 2,
        quad_align_to_boundaries: bool = False,
    ) -> Mesh:
        if not QUAD_REMESH_AVAILABLE:
            raise ImportError("Quad remeshing requires pynim to be installed")
        if quad_vertex_count < 0:
            quad_vertex_count = self.v_pos.shape[0]
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)

        new_vert, new_faces = pynim.remesh(
            v_pos,
            t_pos_idx,
            quad_vertex_count // 4,
            rosy=quad_rosy,
            posy=4,
            creaseAngle=quad_crease_angle,
            align_to_boundaries=quad_align_to_boundaries,
            smooth_iter=quad_smooth_iter,
            deterministic=False,
        )

        # Briefly load in trimesh
        mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    def triangle_remesh(
        self,
        triangle_average_edge_length_multiplier: Optional[float] = None,
        triangle_remesh_steps: int = 10,
        triangle_vertex_count=-1,
    ):
        if not TRIANGLE_REMESH_AVAILABLE:
            raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
        if triangle_vertex_count > 0:
            reduction = triangle_vertex_count / self.v_pos.shape[0]
            print("Triangle reduction:", reduction)
            v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
            t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
            if reduction > 1.0:
                subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
                print("Subdivide iters:", subdivide_iters)
                v_pos, t_pos_idx = gpytoolbox.subdivide(
                    v_pos,
                    t_pos_idx,
                    iters=subdivide_iters,
                )
                reduction = triangle_vertex_count / v_pos.shape[0]

            # Simplify
            points_out, faces_out, _, _ = gpytoolbox.decimate(
                v_pos,
                t_pos_idx,
                face_ratio=reduction,
            )

            # Convert back to torch
            self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
            self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
            self._edges = None
            triangle_average_edge_length_multiplier = None

        edges = self.edges
        if triangle_average_edge_length_multiplier is None:
            h = None
        else:
            h = float(
                torch.linalg.norm(
                    self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
                )
                .mean()
                .item()
                * triangle_average_edge_length_multiplier
            )

        # Convert to numpy
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)

        # Remesh
        v_remesh, f_remesh = gpytoolbox.remesh_botsch(
            v_pos,
            t_pos_idx,
            triangle_remesh_steps,
            h,
        )

        # Convert back to torch
        v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        uv, indices = self.unwrapper(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

    def _compute_edges(self):
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
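The methods above are the mesh post-processing entry points used after isosurface extraction: quad-dominant remeshing (pynim), Botsch-Kobbelt triangle remeshing (gpytoolbox), and UV unwrapping with seam-vertex duplication. A minimal usage sketch follows; it is not part of the repository. It assumes `Mesh`, `QUAD_REMESH_AVAILABLE`, and `TRIANGLE_REMESH_AVAILABLE` are importable from `spar3d.models.mesh`, that the method whose parameter list appears above is named `quad_remesh`, and that an unwrapper backend is available for `unwrap_uv`; the one-triangle input is a toy for illustration only.

# Usage sketch (not from the repository). Assumes the imports below exist as named
# and that `quad_remesh` is the method whose parameters are listed above.
import torch

from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE, Mesh

# Toy single-triangle mesh; real meshes come from the isosurface stage.
v_pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
t_pos_idx = torch.tensor([[0, 1, 2]], dtype=torch.long)
mesh = Mesh(v_pos, t_pos_idx)

if TRIANGLE_REMESH_AVAILABLE:
    # Botsch-Kobbelt remeshing via gpytoolbox; -1 keeps the current vertex budget.
    mesh = mesh.triangle_remesh(triangle_remesh_steps=10, triangle_vertex_count=-1)
elif QUAD_REMESH_AVAILABLE:
    # Quad-dominant remeshing via pynim.
    mesh = mesh.quad_remesh(quad_vertex_count=-1, quad_rosy=4)

# Duplicates seam vertices and stores per-vertex UVs in place (needs an unwrapper backend).
mesh.unwrap_uv(island_padding=0.02)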
    	
        spar3d/models/network.py
    ADDED
    
@@ -0,0 +1,223 @@
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

from spar3d.models.utils import BaseModule, normalize
from spar3d.utils import get_device


def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
    def wrapper(fn):
        if condition:
            if len(kwargs) == 0:
                return decorator_with_args
            return decorator_with_args(*args, **kwargs)(fn)
        else:
            return fn

    return wrapper


class PixelShuffleUpsampleNetwork(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 40
        scale_factor: int = 4

        conv_layers: int = 4
        conv_kernel_size: int = 3

    cfg: Config

    def configure(self) -> None:
        layers = []
        output_channels = self.cfg.out_channels * self.cfg.scale_factor**2

        in_channels = self.cfg.in_channels
        for i in range(self.cfg.conv_layers):
            cur_out_channels = (
                in_channels if i != self.cfg.conv_layers - 1 else output_channels
            )
            layers.append(
                nn.Conv2d(
                    in_channels,
                    cur_out_channels,
                    self.cfg.conv_kernel_size,
                    padding=(self.cfg.conv_kernel_size - 1) // 2,
                )
            )
            if i != self.cfg.conv_layers - 1:
                layers.append(nn.ReLU(inplace=True))

        layers.append(nn.PixelShuffle(self.cfg.scale_factor))

        self.upsample = nn.Sequential(*layers)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        return rearrange(
            self.upsample(
                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
            ),
            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
            Np=3,
        )


class _TruncExp(Function):  # pylint: disable=abstract-method
    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @conditional_decorator(
        custom_fwd, "cuda" in get_device(), cast_inputs=torch.float32
    )
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @conditional_decorator(custom_bwd, "cuda" in get_device())
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        return g * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply


def get_activation(name) -> Callable:
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none" or name == "linear" or name == "identity":
        return lambda x: x
    elif name == "lin2srgb":
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    elif name == "trunc_exp":
        return trunc_exp
    elif name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    elif name == "scale_-11_01":
        return lambda x: x * 0.5 + 0.5
    elif name == "negative":
        return lambda x: -x
    elif name == "normalize_channel_last":
        return lambda x: normalize(x)
    elif name == "normalize_channel_first":
        return lambda x: normalize(x, dim=1)
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")


class LambdaModule(torch.nn.Module):
    def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


def get_activation_module(name) -> torch.nn.Module:
    return LambdaModule(get_activation(name))


@dataclass
class HeadSpec:
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    out_bias: float = 0.0


class MaterialMLP(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120
        n_neurons: int = 64
        activation: str = "silu"
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self) -> None:
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
                        self.cfg.n_neurons,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                nn.Linear(
                    self.cfg.n_neurons,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def keys(self):
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out
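`MaterialMLP` above is the shared multi-head decoder: each `HeadSpec` becomes a small MLP over the per-point feature vector, and `forward` applies the head's `output_activation` (resolved through `get_activation`) plus `out_bias`, optionally restricted via `include`/`exclude`. The sketch below is illustrative only; how `BaseModule` actually ingests its `Config` is defined in `spar3d.models.utils` (not shown in this diff), so the dict-style constructor call is an assumption.

# Illustrative sketch, not repository code. The dict-style config passed to
# MaterialMLP is an assumption about BaseModule's constructor (see spar3d.models.utils).
import torch

from spar3d.models.network import HeadSpec, MaterialMLP, get_activation

heads = [
    HeadSpec(name="albedo", out_channels=3, n_hidden_layers=2, output_activation="sigmoid"),
    HeadSpec(name="density", out_channels=1, n_hidden_layers=1, output_activation="trunc_exp"),
]
mlp = MaterialMLP({"in_channels": 120, "n_neurons": 64, "activation": "silu", "heads": heads})

features = torch.randn(4096, 120)        # per-point features sampled from the triplane
out = mlp(features, include=["albedo"])  # evaluate only the requested heads
assert out["albedo"].shape == (4096, 3)  # sigmoid keeps albedo in [0, 1]

# get_activation can also be used standalone, e.g. linear-to-sRGB conversion:
srgb = get_activation("lin2srgb")(torch.rand(8, 3))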
    	
        spar3d/models/tokenizers/dinov2.py
    ADDED
    
@@ -0,0 +1,1196 @@
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
            """PyTorch DINOv2 model."""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import collections.abc
         | 
| 18 | 
            +
            import math
         | 
| 19 | 
            +
            from dataclasses import dataclass
         | 
| 20 | 
            +
            from typing import Dict, List, Optional, Set, Tuple, Union
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            import torch
         | 
| 23 | 
            +
            import torch.nn.functional as F
         | 
| 24 | 
            +
            import torch.utils.checkpoint
         | 
| 25 | 
            +
            from torch import nn
         | 
| 26 | 
            +
            from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
         | 
| 27 | 
            +
            from transformers.activations import ACT2FN
         | 
| 28 | 
            +
            from transformers.modeling_outputs import (
         | 
| 29 | 
            +
                BackboneOutput,
         | 
| 30 | 
            +
                BaseModelOutput,
         | 
| 31 | 
            +
                BaseModelOutputWithPooling,
         | 
| 32 | 
            +
                ImageClassifierOutput,
         | 
| 33 | 
            +
            )
         | 
| 34 | 
            +
            from transformers.modeling_utils import PreTrainedModel
         | 
| 35 | 
            +
            from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
         | 
| 36 | 
            +
            from transformers.pytorch_utils import (
         | 
| 37 | 
            +
                find_pruneable_heads_and_indices,
         | 
| 38 | 
            +
                prune_linear_layer,
         | 
| 39 | 
            +
            )
         | 
| 40 | 
            +
            from transformers.utils import (
         | 
| 41 | 
            +
                add_code_sample_docstrings,
         | 
| 42 | 
            +
                add_start_docstrings,
         | 
| 43 | 
            +
                add_start_docstrings_to_model_forward,
         | 
| 44 | 
            +
                logging,
         | 
| 45 | 
            +
                replace_return_docstrings,
         | 
| 46 | 
            +
            )
         | 
| 47 | 
            +
            from transformers.utils.backbone_utils import BackboneMixin
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            # General docstring
         | 
| 52 | 
            +
            _CONFIG_FOR_DOC = "Dinov2Config"
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            # Base docstring
         | 
| 55 | 
            +
            _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
         | 
| 56 | 
            +
            _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            # Image classification docstring
         | 
| 59 | 
            +
            _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
         | 
| 63 | 
            +
                "facebook/dinov2-base",
         | 
| 64 | 
            +
                # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
         | 
| 65 | 
            +
            ]
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            class Dinov2Embeddings(nn.Module):
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                Construct the CLS token, mask token, position and patch embeddings.
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def __init__(self, config: Dinov2Config) -> None:
         | 
| 74 | 
            +
                    super().__init__()
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         | 
| 77 | 
            +
                    # register as mask token as it's not used in optimization
         | 
| 78 | 
            +
                    # to avoid the use of find_unused_parameters_true
         | 
| 79 | 
            +
                    # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
         | 
| 80 | 
            +
                    self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
         | 
| 81 | 
            +
                    self.patch_embeddings = Dinov2PatchEmbeddings(config)
         | 
| 82 | 
            +
                    num_patches = self.patch_embeddings.num_patches
         | 
| 83 | 
            +
                    self.position_embeddings = nn.Parameter(
         | 
| 84 | 
            +
                        torch.randn(1, num_patches + 1, config.hidden_size)
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
         | 
| 87 | 
            +
                    self.config = config
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def interpolate_pos_encoding(
         | 
| 90 | 
            +
                    self, embeddings: torch.Tensor, height: int, width: int
         | 
| 91 | 
            +
                ) -> torch.Tensor:
         | 
| 92 | 
            +
                    """
         | 
| 93 | 
            +
                    This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
         | 
| 94 | 
            +
                    resolution images.
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    Source:
         | 
| 97 | 
            +
                    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         | 
| 98 | 
            +
                    """
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    num_patches = embeddings.shape[1] - 1
         | 
| 101 | 
            +
                    num_positions = self.position_embeddings.shape[1] - 1
         | 
| 102 | 
            +
                    if num_patches == num_positions and height == width:
         | 
| 103 | 
            +
                        return self.position_embeddings
         | 
| 104 | 
            +
                    class_pos_embed = self.position_embeddings[:, 0]
         | 
| 105 | 
            +
                    patch_pos_embed = self.position_embeddings[:, 1:]
         | 
| 106 | 
            +
                    dim = embeddings.shape[-1]
         | 
| 107 | 
            +
                    height = height // self.config.patch_size
         | 
| 108 | 
            +
                    width = width // self.config.patch_size
         | 
| 109 | 
            +
                    # we add a small number to avoid floating point error in the interpolation
         | 
| 110 | 
            +
                    # see discussion at https://github.com/facebookresearch/dino/issues/8
         | 
| 111 | 
            +
                    height, width = height + 0.1, width + 0.1
         | 
| 112 | 
            +
                    patch_pos_embed = patch_pos_embed.reshape(
         | 
| 113 | 
            +
                        1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
         | 
| 114 | 
            +
                    )
         | 
| 115 | 
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
         | 
| 116 | 
            +
                    patch_pos_embed = nn.functional.interpolate(
         | 
| 117 | 
            +
                        patch_pos_embed,
         | 
| 118 | 
            +
                        scale_factor=(
         | 
| 119 | 
            +
                            height / math.sqrt(num_positions),
         | 
| 120 | 
            +
                            width / math.sqrt(num_positions),
         | 
| 121 | 
            +
                        ),
         | 
| 122 | 
            +
                        mode="bicubic",
         | 
| 123 | 
            +
                        align_corners=False,
         | 
| 124 | 
            +
                    )
         | 
| 125 | 
            +
                    if (
         | 
| 126 | 
            +
                        int(height) != patch_pos_embed.shape[-2]
         | 
| 127 | 
            +
                        or int(width) != patch_pos_embed.shape[-1]
         | 
| 128 | 
            +
                    ):
         | 
| 129 | 
            +
                        raise ValueError(
         | 
| 130 | 
            +
                            "Width or height does not match with the interpolated position embeddings"
         | 
| 131 | 
            +
                        )
         | 
| 132 | 
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         | 
| 133 | 
            +
                    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def forward(
         | 
| 136 | 
            +
                    self,
         | 
| 137 | 
            +
                    pixel_values: torch.Tensor,
         | 
| 138 | 
            +
                    bool_masked_pos: Optional[torch.Tensor] = None,
         | 
| 139 | 
            +
                ) -> torch.Tensor:
         | 
| 140 | 
            +
                    batch_size, _, height, width = pixel_values.shape
         | 
| 141 | 
            +
                    patch_embeddings = self.patch_embeddings(pixel_values)
         | 
| 142 | 
            +
                    embeddings = patch_embeddings
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if bool_masked_pos is not None:
         | 
| 145 | 
            +
                        embeddings = torch.where(
         | 
| 146 | 
            +
                            bool_masked_pos.unsqueeze(-1),
         | 
| 147 | 
            +
                            self.mask_token.to(embeddings.dtype).unsqueeze(0),
         | 
| 148 | 
            +
                            embeddings,
         | 
| 149 | 
            +
                        )
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    # add the [CLS] token to the embedded patch tokens
         | 
| 152 | 
            +
                    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         | 
| 153 | 
            +
                    embeddings = torch.cat((cls_tokens, embeddings), dim=1)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # add positional encoding to each token
         | 
| 156 | 
            +
                    embeddings = embeddings + self.interpolate_pos_encoding(
         | 
| 157 | 
            +
                        embeddings, height, width
         | 
| 158 | 
            +
                    )
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    embeddings = self.dropout(embeddings)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    return embeddings
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            class Dinov2PatchEmbeddings(nn.Module):
         | 
| 166 | 
            +
                """
         | 
| 167 | 
            +
                This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
         | 
| 168 | 
            +
                `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
         | 
| 169 | 
            +
                Transformer.
         | 
| 170 | 
            +
                """
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                def __init__(self, config):
         | 
| 173 | 
            +
                    super().__init__()
         | 
| 174 | 
            +
                    image_size, patch_size = config.image_size, config.patch_size
         | 
| 175 | 
            +
                    num_channels, hidden_size = config.num_channels, config.hidden_size
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    image_size = (
         | 
| 178 | 
            +
                        image_size
         | 
| 179 | 
            +
                        if isinstance(image_size, collections.abc.Iterable)
         | 
| 180 | 
            +
                        else (image_size, image_size)
         | 
| 181 | 
            +
                    )
         | 
| 182 | 
            +
                    patch_size = (
         | 
| 183 | 
            +
                        patch_size
         | 
| 184 | 
            +
                        if isinstance(patch_size, collections.abc.Iterable)
         | 
| 185 | 
            +
                        else (patch_size, patch_size)
         | 
| 186 | 
            +
                    )
         | 
| 187 | 
            +
                    num_patches = (image_size[1] // patch_size[1]) * (
         | 
| 188 | 
            +
                        image_size[0] // patch_size[0]
         | 
| 189 | 
            +
                    )
         | 
| 190 | 
            +
                    self.image_size = image_size
         | 
| 191 | 
            +
                    self.patch_size = patch_size
         | 
| 192 | 
            +
                    self.num_channels = num_channels
         | 
| 193 | 
            +
                    self.num_patches = num_patches
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    self.projection = nn.Conv2d(
         | 
| 196 | 
            +
                        num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
         | 
| 197 | 
            +
                    )
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         | 
| 200 | 
            +
                    """
         | 
| 201 | 
            +
                    num_channels = pixel_values.shape[1]
         | 
| 202 | 
            +
                    if num_channels != self.num_channels:
         | 
| 203 | 
            +
                        raise ValueError(
         | 
| 204 | 
            +
                            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
         | 
| 205 | 
            +
                            f" Expected {self.num_channels} but got {num_channels}."
         | 
| 206 | 
            +
                        )
         | 
| 207 | 
            +
                    """
         | 
| 208 | 
            +
                    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
         | 
| 209 | 
            +
                    return embeddings
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
         | 
| 213 | 
            +
            class Dinov2SelfAttention(nn.Module):
         | 
| 214 | 
            +
                def __init__(self, config: Dinov2Config) -> None:
         | 
| 215 | 
            +
                    super().__init__()
         | 
| 216 | 
            +
                    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
         | 
| 217 | 
            +
                        config, "embedding_size"
         | 
| 218 | 
            +
                    ):
         | 
| 219 | 
            +
                        raise ValueError(
         | 
| 220 | 
            +
                            f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
         | 
| 221 | 
            +
                            f"heads {config.num_attention_heads}."
         | 
| 222 | 
            +
                        )
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    self.num_attention_heads = config.num_attention_heads
         | 
| 225 | 
            +
                    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         | 
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        if hasattr(F, "scaled_dot_product_attention"):
            assert head_mask is None and not output_attentions
            new_size = hidden_states.size()[:-1] + (
                self.num_attention_heads,
                self.attention_head_size,
            )
            key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
            value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
            query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                # Only apply attention dropout while training, matching the eager
                # path below where `self.dropout` is a no-op in eval mode.
                dropout_p=self.attention_probs_dropout_prob if self.training else 0.0,
                is_causal=False,
            )
            context_layer = context_layer.transpose(1, 2).reshape(
                *hidden_states.size()[:-1], -1
            )
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            query_layer = self.transpose_for_scores(mixed_query_layer)

            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)

            # Mask heads if we want to
            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs, value_layer)

            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        return outputs

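
# Illustrative sketch (not used anywhere in this file): the head split/merge that
# `transpose_for_scores` and the SDPA branch above perform, written out with assumed
# toy dimensions (batch 2, sequence length 5, 12 heads of size 64). Relies on the
# module-level `torch` import.
def _example_head_split_merge() -> None:  # hypothetical helper, for illustration only
    batch, seq, num_heads, head_size = 2, 5, 12, 64
    hidden = torch.randn(batch, seq, num_heads * head_size)
    # [batch, seq, all_head_size] -> [batch, num_heads, seq, head_size]
    per_head = hidden.view(batch, seq, num_heads, head_size).permute(0, 2, 1, 3)
    # ... attention mixes information along the `seq` axis independently per head ...
    # [batch, num_heads, seq, head_size] -> [batch, seq, all_head_size]
    merged = per_head.permute(0, 2, 1, 3).reshape(batch, seq, num_heads * head_size)
    assert merged.shape == hidden.shape
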
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
    """
    The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states

# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
class Dinov2Attention(nn.Module):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.attention.num_attention_heads,
            self.attention.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(
            heads
        )
        self.attention.all_head_size = (
            self.attention.attention_head_size * self.attention.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs

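
# Usage sketch (illustrative; the config values are assumptions): pruning two of the
# four heads shrinks the query/key/value projections and the output projection, and
# `all_head_size` is updated to match.
def _example_prune_heads() -> None:  # hypothetical helper, for illustration only
    config = Dinov2Config(hidden_size=64, num_attention_heads=4)
    attention = Dinov2Attention(config)
    attention.prune_heads({0, 3})
    assert attention.attention.num_attention_heads == 2
    assert attention.attention.all_head_size == 2 * (64 // 4)
    assert attention.attention.query.out_features == 32
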
class Dinov2LayerScale(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.lambda1 = nn.Parameter(
            config.layerscale_value * torch.ones(config.hidden_size)
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return hidden_state * self.lambda1

| 390 | 
            +
             | 
| 391 | 
            +
            # Copied from transformers.models.beit.modeling_beit.drop_path
         | 
| 392 | 
            +
            def drop_path(
         | 
| 393 | 
            +
                input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
         | 
| 394 | 
            +
            ) -> torch.Tensor:
         | 
| 395 | 
            +
                """
         | 
| 396 | 
            +
                Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
         | 
| 399 | 
            +
                however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
         | 
| 400 | 
            +
                See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
         | 
| 401 | 
            +
                layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
         | 
| 402 | 
            +
                argument.
         | 
| 403 | 
            +
                """
         | 
| 404 | 
            +
                if drop_prob == 0.0 or not training:
         | 
| 405 | 
            +
                    return input
         | 
| 406 | 
            +
                keep_prob = 1 - drop_prob
         | 
| 407 | 
            +
                shape = (input.shape[0],) + (1,) * (
         | 
| 408 | 
            +
                    input.ndim - 1
         | 
| 409 | 
            +
                )  # work with diff dim tensors, not just 2D ConvNets
         | 
| 410 | 
            +
                random_tensor = keep_prob + torch.rand(
         | 
| 411 | 
            +
                    shape, dtype=input.dtype, device=input.device
         | 
| 412 | 
            +
                )
         | 
| 413 | 
            +
                random_tensor.floor_()  # binarize
         | 
| 414 | 
            +
                output = input.div(keep_prob) * random_tensor
         | 
| 415 | 
            +
                return output
         | 
| 416 | 
            +
             | 
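
# Illustrative sketch (not part of the model): what `drop_path` does per sample.
# With drop_prob=0.25 and training=True, each sample's residual branch is either
# zeroed or rescaled by 1 / keep_prob, so the expected value is unchanged; in eval
# mode (or with drop_prob=0.0) the input passes through untouched.
def _example_drop_path_behaviour() -> None:  # hypothetical helper, for illustration only
    x = torch.ones(8, 4)
    out = drop_path(x, drop_prob=0.25, training=True)
    # Every row is either all zeros or all 1 / 0.75.
    assert torch.all((out == 0) | torch.isclose(out, torch.tensor(1.0 / 0.75)))
    assert torch.equal(drop_path(x, drop_prob=0.25, training=False), x)
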
| 417 | 
            +
             | 
| 418 | 
            +
            # Copied from transformers.models.beit.modeling_beit.BeitDropPath
         | 
| 419 | 
            +
            class Dinov2DropPath(nn.Module):
         | 
| 420 | 
            +
                """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                def __init__(self, drop_prob: Optional[float] = None) -> None:
         | 
| 423 | 
            +
                    super().__init__()
         | 
| 424 | 
            +
                    self.drop_prob = drop_prob
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         | 
| 427 | 
            +
                    return drop_path(hidden_states, self.drop_prob, self.training)
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                def extra_repr(self) -> str:
         | 
| 430 | 
            +
                    return "p={}".format(self.drop_prob)
         | 
| 431 | 
            +
             | 
| 432 | 
            +
             | 
class Dinov2MLP(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state

class Dinov2SwiGLUFFN(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
        self.weights_out = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.weights_in(hidden_state)
        x1, x2 = hidden_state.chunk(2, dim=-1)
        hidden = nn.functional.silu(x1) * x2
        return self.weights_out(hidden)

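
# Illustrative sketch (for exposition only): how the SwiGLU hidden width above is
# chosen. The 2/3 factor keeps the parameter count close to a plain MLP with the same
# mlp_ratio (SwiGLU uses three weight matrices instead of two), and the result is
# rounded up to a multiple of 8. With the assumed hidden_size=384 and mlp_ratio=4
# this gives 1536 -> 1024.
def _example_swiglu_hidden_width(hidden_size: int = 384, mlp_ratio: float = 4.0) -> int:
    hidden_features = int(hidden_size * mlp_ratio)
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8
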
class Dinov2Layer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm1_modulation = None
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        self.drop_path1 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2_modulation = None

        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)
        self.drop_path2 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        hidden_states_norm = self.norm1(hidden_states)
        if self.norm1_modulation is not None:
            assert modulation_cond is not None
            hidden_states_norm = self.norm1_modulation(
                hidden_states_norm, modulation_cond
            )
        self_attention_outputs = self.attention(
            hidden_states_norm,  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        if self.norm2_modulation is not None:
            assert modulation_cond is not None
            layer_output = self.norm2_modulation(layer_output, modulation_cond)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = layer_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

    def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
        self.norm1_modulation = norm1_mod
        self.norm2_modulation = norm2_mod

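
# Illustrative sketch (an assumption, not part of this file): the interface a module
# registered through `register_ada_norm_modulation` has to satisfy. It is called with
# the layer-normed hidden states and a per-sample conditioning tensor and must return
# a tensor of the same shape, e.g. an AdaLN-style scale/shift.
class _ExampleAdaNormModulation(nn.Module):  # hypothetical, for illustration only
    def __init__(self, hidden_size: int, cond_dim: int) -> None:
        super().__init__()
        self.to_scale_shift = nn.Linear(cond_dim, 2 * hidden_size)

    def forward(self, hidden_states: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond: [batch, cond_dim] -> scale, shift: [batch, 1, hidden_size] each
        scale, shift = self.to_scale_shift(cond).unsqueeze(1).chunk(2, dim=-1)
        return hidden_states * (1 + scale) + shift
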
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
class Dinov2Encoder(nn.Module):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                    modulation_cond,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, layer_head_mask, modulation_cond, output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

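
# Usage sketch (illustrative): when `gradient_checkpointing` is True and the encoder
# is in training mode, each layer is recomputed during the backward pass instead of
# keeping its activations in memory, trading extra compute for a smaller footprint.
def _example_enable_gradient_checkpointing(encoder: Dinov2Encoder) -> None:
    # hypothetical helper, for illustration only; has no effect outside training mode
    encoder.gradient_checkpointing = True
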
class Dinov2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

    def _set_gradient_checkpointing(
        self, module: Dinov2Encoder, value: bool = False
    ) -> None:
        if isinstance(module, Dinov2Encoder):
            module.gradient_checkpointing = value

DINOV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DINOV2_BASE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

DINOV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@dataclass
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
    patch_embeddings: Optional[torch.FloatTensor] = None

@add_start_docstrings(
    "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
    DINOV2_START_DOCSTRING,
)
class Dinov2Model(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    def expand_input_channels(self, extra_input_channels: int) -> None:
        if extra_input_channels == 0:
            return
        conv_old = self.embeddings.patch_embeddings.projection
        conv_new = nn.Conv2d(
            self.config.num_channels + extra_input_channels,
            self.config.hidden_size,
            kernel_size=self.config.patch_size,
            stride=self.config.patch_size,
        ).to(self.device)
        with torch.no_grad():
            conv_new.weight[:, :3] = conv_old.weight
            conv_new.bias = conv_old.bias
        self.embeddings.patch_embeddings.projection = conv_new
        del conv_old

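    # Note: `expand_input_channels` keeps the pretrained projection weights for the
    # first three input channels (the slice above assumes `config.num_channels == 3`)
    # and leaves the newly added channels at the default Conv2d initialization. A
    # runnable sketch is given after this class.
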
            +
                def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
         | 
| 764 | 
            +
                    """
         | 
| 765 | 
            +
                    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
         | 
| 766 | 
            +
                    class PreTrainedModel
         | 
| 767 | 
            +
                    """
         | 
| 768 | 
            +
                    for layer, heads in heads_to_prune.items():
         | 
| 769 | 
            +
                        self.encoder.layer[layer].attention.prune_heads(heads)
         | 
| 770 | 
            +
             | 
| 771 | 
            +
                @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
         | 
| 772 | 
            +
                @add_code_sample_docstrings(
         | 
| 773 | 
            +
                    checkpoint=_CHECKPOINT_FOR_DOC,
         | 
| 774 | 
            +
                    output_type=BaseModelOutputWithPooling,
         | 
| 775 | 
            +
                    config_class=_CONFIG_FOR_DOC,
         | 
| 776 | 
            +
                    modality="vision",
         | 
| 777 | 
            +
                    expected_output=_EXPECTED_OUTPUT_SHAPE,
         | 
| 778 | 
            +
                )
         | 
| 779 | 
            +
                def forward(
         | 
| 780 | 
            +
                    self,
         | 
| 781 | 
            +
                    pixel_values: Optional[torch.Tensor] = None,
         | 
| 782 | 
            +
                    bool_masked_pos: Optional[torch.Tensor] = None,
         | 
| 783 | 
            +
                    head_mask: Optional[torch.Tensor] = None,
         | 
| 784 | 
            +
                    modulation_cond: Optional[torch.Tensor] = None,
         | 
| 785 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 786 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 787 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 788 | 
            +
                ) -> Union[Tuple, BaseModelOutputWithPooling]:
         | 
| 789 | 
            +
                    output_attentions = (
         | 
| 790 | 
            +
                        output_attentions
         | 
| 791 | 
            +
                        if output_attentions is not None
         | 
| 792 | 
            +
                        else self.config.output_attentions
         | 
| 793 | 
            +
                    )
         | 
| 794 | 
            +
                    output_hidden_states = (
         | 
| 795 | 
            +
                        output_hidden_states
         | 
| 796 | 
            +
                        if output_hidden_states is not None
         | 
| 797 | 
            +
                        else self.config.output_hidden_states
         | 
| 798 | 
            +
                    )
         | 
| 799 | 
            +
                    return_dict = (
         | 
| 800 | 
            +
                        return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 801 | 
            +
                    )
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                    if pixel_values is None:
         | 
| 804 | 
            +
                        raise ValueError("You have to specify pixel_values")
         | 
| 805 | 
            +
             | 
| 806 | 
            +
                    # Prepare head mask if needed
         | 
| 807 | 
            +
                    # 1.0 in head_mask indicate we keep the head
         | 
| 808 | 
            +
                    # attention_probs has shape bsz x n_heads x N x N
         | 
| 809 | 
            +
                    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         | 
| 810 | 
            +
                    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         | 
| 811 | 
            +
                    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         | 
| 812 | 
            +
             | 
| 813 | 
            +
                    embedding_output = self.embeddings(
         | 
| 814 | 
            +
                        pixel_values, bool_masked_pos=bool_masked_pos
         | 
| 815 | 
            +
                    )
         | 
| 816 | 
            +
             | 
| 817 | 
            +
                    encoder_outputs = self.encoder(
         | 
| 818 | 
            +
                        embedding_output,
         | 
| 819 | 
            +
                        head_mask=head_mask,
         | 
| 820 | 
            +
                        modulation_cond=modulation_cond,
         | 
| 821 | 
            +
                        output_attentions=output_attentions,
         | 
| 822 | 
            +
                        output_hidden_states=output_hidden_states,
         | 
| 823 | 
            +
                        return_dict=return_dict,
         | 
| 824 | 
            +
                    )
         | 
| 825 | 
            +
                    sequence_output = encoder_outputs[0]
         | 
| 826 | 
            +
                    sequence_output = self.layernorm(sequence_output)
         | 
| 827 | 
            +
                    pooled_output = sequence_output[:, 0, :]
         | 
| 828 | 
            +
             | 
| 829 | 
            +
                    if not return_dict:
         | 
| 830 | 
            +
                        head_outputs = (sequence_output, pooled_output)
         | 
| 831 | 
            +
                        return head_outputs + encoder_outputs[1:]
         | 
| 832 | 
            +
             | 
| 833 | 
            +
                    return CustomBaseModelOutputWithPooling(
         | 
| 834 | 
            +
                        last_hidden_state=sequence_output,
         | 
| 835 | 
            +
                        pooler_output=pooled_output,
         | 
| 836 | 
            +
                        hidden_states=encoder_outputs.hidden_states,
         | 
| 837 | 
            +
                        attentions=encoder_outputs.attentions,
         | 
| 838 | 
            +
                        patch_embeddings=embedding_output,
         | 
| 839 | 
            +
                    )
         | 
| 840 | 
            +
             | 
| 841 | 
            +
                def set_gradient_checkpointing(self, value: bool = False) -> None:
         | 
| 842 | 
            +
                    self._set_gradient_checkpointing(self.encoder, value)
         | 
| 843 | 
            +
             | 
| 844 | 
            +
             | 
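
# Usage sketch (illustrative; the small config values are assumptions chosen to keep
# the example cheap to run): build the bare model, check its outputs, then widen the
# patch projection with `expand_input_channels` and confirm the pretrained RGB
# filters are preserved. `pooler_output` is the layer-normed [CLS] token and
# `patch_embeddings` exposes the pre-encoder embedding output added by
# `CustomBaseModelOutputWithPooling`.
def _example_dinov2_model_usage() -> None:  # hypothetical helper, for illustration only
    config = Dinov2Config(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        image_size=28,
        patch_size=14,
        num_channels=3,
    )
    model = Dinov2Model(config).eval()

    with torch.no_grad():
        outputs = model(pixel_values=torch.randn(1, 3, 28, 28))
    assert outputs.last_hidden_state.shape == (1, 1 + (28 // 14) ** 2, 64)
    assert outputs.pooler_output.shape == (1, 64)
    assert outputs.patch_embeddings.shape == outputs.last_hidden_state.shape

    old_weight = model.get_input_embeddings().projection.weight.detach().clone()
    model.expand_input_channels(extra_input_channels=3)
    new_weight = model.get_input_embeddings().projection.weight.detach()
    assert new_weight.shape[1] == 6
    assert torch.equal(new_weight[:, :3], old_weight)
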
@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size

        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]

        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

| 950 | 
            +
            @add_start_docstrings(
         | 
| 951 | 
            +
                """
         | 
| 952 | 
            +
                Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
         | 
| 953 | 
            +
                """,
         | 
| 954 | 
            +
                DINOV2_START_DOCSTRING,
         | 
| 955 | 
            +
            )
         | 
| 956 | 
            +
            class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
         | 
| 957 | 
            +
                def __init__(self, config):
         | 
| 958 | 
            +
                    super().__init__(config)
         | 
| 959 | 
            +
                    super()._init_backbone(config)
         | 
| 960 | 
            +
             | 
| 961 | 
            +
                    self.num_features = [
         | 
| 962 | 
            +
                        config.hidden_size for _ in range(config.num_hidden_layers + 1)
         | 
| 963 | 
            +
                    ]
         | 
| 964 | 
            +
                    self.embeddings = Dinov2Embeddings(config)
         | 
| 965 | 
            +
                    self.encoder = Dinov2Encoder(config)
         | 
| 966 | 
            +
             | 
| 967 | 
            +
                    self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         | 
| 968 | 
            +
             | 
| 969 | 
            +
                    # Initialize weights and apply final processing
         | 
| 970 | 
            +
                    self.post_init()
         | 
| 971 | 
            +
             | 
| 972 | 
            +
                def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
         | 
| 973 | 
            +
                    return self.embeddings.patch_embeddings
         | 
| 974 | 
            +
             | 
| 975 | 
            +
                @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
         | 
| 976 | 
            +
                @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
         | 
| 977 | 
            +
                def forward(
         | 
| 978 | 
            +
                    self,
         | 
| 979 | 
            +
                    pixel_values: torch.Tensor,
         | 
| 980 | 
            +
                    output_hidden_states: Optional[bool] = None,
         | 
| 981 | 
            +
                    output_attentions: Optional[bool] = None,
         | 
| 982 | 
            +
                    return_dict: Optional[bool] = None,
         | 
| 983 | 
            +
                ) -> BackboneOutput:
         | 
| 984 | 
            +
                    """
         | 
| 985 | 
            +
                    Returns:
         | 
| 986 | 
            +
             | 
| 987 | 
            +
                    Examples:
         | 
| 988 | 
            +
             | 
| 989 | 
            +
                    ```python
         | 
| 990 | 
            +
                    >>> from transformers import AutoImageProcessor, AutoBackbone
         | 
| 991 | 
            +
                    >>> import torch
         | 
| 992 | 
            +
                    >>> from PIL import Image
         | 
| 993 | 
            +
                    >>> import requests
         | 
| 994 | 
            +
             | 
| 995 | 
            +
                    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         | 
| 996 | 
            +
                    >>> image = Image.open(requests.get(url, stream=True).raw)
         | 
| 997 | 
            +
             | 
| 998 | 
            +
                    >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
         | 
| 999 | 
            +
                    >>> model = AutoBackbone.from_pretrained(
         | 
| 1000 | 
            +
                    ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
         | 
| 1001 | 
            +
                    ... )
         | 
| 1002 | 
            +
             | 
| 1003 | 
            +
                    >>> inputs = processor(image, return_tensors="pt")
         | 
| 1004 | 
            +
             | 
| 1005 | 
            +
                    >>> outputs = model(**inputs)
         | 
| 1006 | 
            +
                    >>> feature_maps = outputs.feature_maps
         | 
| 1007 | 
            +
                    >>> list(feature_maps[-1].shape)
         | 
| 1008 | 
            +
                    [1, 768, 16, 16]
         | 
| 1009 | 
            +
                    ```"""
         | 
| 1010 | 
            +
                    return_dict = (
         | 
| 1011 | 
            +
                        return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 1012 | 
            +
                    )
         | 
| 1013 | 
            +
                    output_hidden_states = (
         | 
| 1014 | 
            +
                        output_hidden_states
         | 
| 1015 | 
            +
                        if output_hidden_states is not None
         | 
| 1016 | 
            +
                        else self.config.output_hidden_states
         | 
| 1017 | 
            +
                    )
         | 
| 1018 | 
            +
                    output_attentions = (
         | 
| 1019 | 
            +
                        output_attentions
         | 
| 1020 | 
            +
                        if output_attentions is not None
         | 
| 1021 | 
            +
                        else self.config.output_attentions
         | 
| 1022 | 
            +
                    )
         | 
| 1023 | 
            +
             | 
| 1024 | 
            +
                    embedding_output = self.embeddings(pixel_values)
         | 
| 1025 | 
            +
             | 
| 1026 | 
            +
                    outputs = self.encoder(
         | 
| 1027 | 
            +
                        embedding_output,
         | 
| 1028 | 
            +
                        output_hidden_states=True,
         | 
| 1029 | 
            +
                        output_attentions=output_attentions,
         | 
| 1030 | 
            +
                        return_dict=return_dict,
         | 
| 1031 | 
            +
                    )
         | 
| 1032 | 
            +
             | 
| 1033 | 
            +
                    hidden_states = outputs.hidden_states if return_dict else outputs[1]
         | 
| 1034 | 
            +
             | 
| 1035 | 
            +
                    feature_maps = ()
         | 
| 1036 | 
            +
                    for stage, hidden_state in zip(self.stage_names, hidden_states):
         | 
| 1037 | 
            +
                        if stage in self.out_features:
         | 
| 1038 | 
            +
                            if self.config.apply_layernorm:
         | 
| 1039 | 
            +
                                hidden_state = self.layernorm(hidden_state)
         | 
| 1040 | 
            +
                            if self.config.reshape_hidden_states:
         | 
| 1041 | 
            +
                                batch_size, _, height, width = pixel_values.shape
         | 
| 1042 | 
            +
                                patch_size = self.config.patch_size
         | 
| 1043 | 
            +
                                hidden_state = hidden_state[:, 1:, :].reshape(
         | 
| 1044 | 
            +
                                    batch_size, width // patch_size, height // patch_size, -1
         | 
| 1045 | 
            +
                                )
         | 
| 1046 | 
            +
                                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
         | 
| 1047 | 
            +
                            feature_maps += (hidden_state,)
         | 
| 1048 | 
            +
             | 
| 1049 | 
            +
                    if not return_dict:
         | 
| 1050 | 
            +
                        if output_hidden_states:
         | 
| 1051 | 
            +
                            output = (feature_maps,) + outputs[1:]
         | 
| 1052 | 
            +
                        else:
         | 
| 1053 | 
            +
                            output = (feature_maps,) + outputs[2:]
         | 
| 1054 | 
            +
                        return output
         | 
| 1055 | 
            +
             | 
| 1056 | 
            +
                    return BackboneOutput(
         | 
| 1057 | 
            +
                        feature_maps=feature_maps,
         | 
| 1058 | 
            +
                        hidden_states=outputs.hidden_states if output_hidden_states else None,
         | 
| 1059 | 
            +
                        attentions=outputs.attentions if output_attentions else None,
         | 
| 1060 | 
            +
                    )
         | 
| 1061 | 
            +
             | 
| 1062 | 
            +
             | 
| 1063 | 
            +
            class CustomPatchEmbeddings(nn.Module):
         | 
| 1064 | 
            +
                """
         | 
| 1065 | 
            +
                This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
         | 
| 1066 | 
            +
                `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
         | 
| 1067 | 
            +
                Transformer.
         | 
| 1068 | 
            +
                """
         | 
| 1069 | 
            +
             | 
| 1070 | 
            +
                def __init__(
         | 
| 1071 | 
            +
                    self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
         | 
| 1072 | 
            +
                ):
         | 
| 1073 | 
            +
                    super().__init__()
         | 
| 1074 | 
            +
             | 
| 1075 | 
            +
                    image_size = (
         | 
| 1076 | 
            +
                        image_size
         | 
| 1077 | 
            +
                        if isinstance(image_size, collections.abc.Iterable)
         | 
| 1078 | 
            +
                        else (image_size, image_size)
         | 
| 1079 | 
            +
                    )
         | 
| 1080 | 
            +
                    patch_size = (
         | 
| 1081 | 
            +
                        patch_size
         | 
| 1082 | 
            +
                        if isinstance(patch_size, collections.abc.Iterable)
         | 
| 1083 | 
            +
                        else (patch_size, patch_size)
         | 
| 1084 | 
            +
                    )
         | 
| 1085 | 
            +
                    num_patches = (image_size[1] // patch_size[1]) * (
         | 
| 1086 | 
            +
                        image_size[0] // patch_size[0]
         | 
| 1087 | 
            +
                    )
         | 
| 1088 | 
            +
                    self.image_size = image_size
         | 
| 1089 | 
            +
                    self.patch_size = patch_size
         | 
| 1090 | 
            +
                    self.num_channels = num_channels
         | 
| 1091 | 
            +
                    self.num_patches = num_patches
         | 
| 1092 | 
            +
             | 
| 1093 | 
            +
                    self.projection = nn.Conv2d(
         | 
| 1094 | 
            +
                        num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
         | 
| 1095 | 
            +
                    )
         | 
| 1096 | 
            +
             | 
| 1097 | 
            +
                def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         | 
| 1098 | 
            +
                    num_channels = pixel_values.shape[1]
         | 
| 1099 | 
            +
                    if num_channels != self.num_channels:
         | 
| 1100 | 
            +
                        raise ValueError(
         | 
| 1101 | 
            +
                            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
         | 
| 1102 | 
            +
                            f" Expected {self.num_channels} but got {num_channels}."
         | 
| 1103 | 
            +
                        )
         | 
| 1104 | 
            +
                    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
         | 
| 1105 | 
            +
                    return embeddings
         | 
| 1106 | 
            +
             | 
| 1107 | 
            +
             | 
| 1108 | 
            +
            class CustomEmbeddings(nn.Module):
         | 
| 1109 | 
            +
                """
         | 
| 1110 | 
            +
                Construct the CLS token, mask token, position and patch embeddings.
         | 
| 1111 | 
            +
                """
         | 
| 1112 | 
            +
             | 
| 1113 | 
            +
                def __init__(
         | 
| 1114 | 
            +
                    self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
         | 
| 1115 | 
            +
                ) -> None:
         | 
| 1116 | 
            +
                    super().__init__()
         | 
| 1117 | 
            +
             | 
| 1118 | 
            +
                    self.image_size = image_size
         | 
| 1119 | 
            +
                    self.patch_size = patch_size
         | 
| 1120 | 
            +
                    self.num_channels = num_channels
         | 
| 1121 | 
            +
                    self.hidden_size = hidden_size
         | 
| 1122 | 
            +
             | 
| 1123 | 
            +
                    self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
         | 
| 1124 | 
            +
             | 
| 1125 | 
            +
                    self.patch_embeddings = CustomPatchEmbeddings(
         | 
| 1126 | 
            +
                        image_size, patch_size, num_channels, hidden_size
         | 
| 1127 | 
            +
                    )
         | 
| 1128 | 
            +
                    num_patches = self.patch_embeddings.num_patches
         | 
| 1129 | 
            +
                    self.position_embeddings = nn.Parameter(
         | 
| 1130 | 
            +
                        torch.randn(1, num_patches + 1, self.hidden_size)
         | 
| 1131 | 
            +
                    )
         | 
| 1132 | 
            +
             | 
| 1133 | 
            +
                def interpolate_pos_encoding(
         | 
| 1134 | 
            +
                    self, embeddings: torch.Tensor, height: int, width: int
         | 
| 1135 | 
            +
                ) -> torch.Tensor:
         | 
| 1136 | 
            +
                    """
         | 
| 1137 | 
            +
                    This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
         | 
| 1138 | 
            +
                    resolution images.
         | 
| 1139 | 
            +
             | 
| 1140 | 
            +
                    Source:
         | 
| 1141 | 
            +
                    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
         | 
| 1142 | 
            +
                    """
         | 
| 1143 | 
            +
             | 
| 1144 | 
            +
                    num_patches = embeddings.shape[1] - 1
         | 
| 1145 | 
            +
                    num_positions = self.position_embeddings.shape[1] - 1
         | 
| 1146 | 
            +
                    if num_patches == num_positions and height == width:
         | 
| 1147 | 
            +
                        return self.position_embeddings
         | 
| 1148 | 
            +
                    class_pos_embed = self.position_embeddings[:, 0]
         | 
| 1149 | 
            +
                    patch_pos_embed = self.position_embeddings[:, 1:]
         | 
| 1150 | 
            +
                    dim = embeddings.shape[-1]
         | 
| 1151 | 
            +
                    height = height // self.patch_size
         | 
| 1152 | 
            +
                    width = width // self.patch_size
         | 
| 1153 | 
            +
                    # we add a small number to avoid floating point error in the interpolation
         | 
| 1154 | 
            +
                    # see discussion at https://github.com/facebookresearch/dino/issues/8
         | 
| 1155 | 
            +
                    height, width = height + 0.1, width + 0.1
         | 
| 1156 | 
            +
                    patch_pos_embed = patch_pos_embed.reshape(
         | 
| 1157 | 
            +
                        1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
         | 
| 1158 | 
            +
                    )
         | 
| 1159 | 
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
         | 
| 1160 | 
            +
                    patch_pos_embed = nn.functional.interpolate(
         | 
| 1161 | 
            +
                        patch_pos_embed,
         | 
| 1162 | 
            +
                        scale_factor=(
         | 
| 1163 | 
            +
                            height / math.sqrt(num_positions),
         | 
| 1164 | 
            +
                            width / math.sqrt(num_positions),
         | 
| 1165 | 
            +
                        ),
         | 
| 1166 | 
            +
                        mode="bicubic",
         | 
| 1167 | 
            +
                        align_corners=False,
         | 
| 1168 | 
            +
                    )
         | 
| 1169 | 
            +
                    if (
         | 
| 1170 | 
            +
                        int(height) != patch_pos_embed.shape[-2]
         | 
| 1171 | 
            +
                        or int(width) != patch_pos_embed.shape[-1]
         | 
| 1172 | 
            +
                    ):
         | 
| 1173 | 
            +
                        raise ValueError(
         | 
| 1174 | 
            +
                            "Width or height does not match with the interpolated position embeddings"
         | 
| 1175 | 
            +
                        )
         | 
| 1176 | 
            +
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         | 
| 1177 | 
            +
                    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
         | 
| 1178 | 
            +
             | 
| 1179 | 
            +
                def forward(
         | 
| 1180 | 
            +
                    self,
         | 
| 1181 | 
            +
                    pixel_values: torch.Tensor,
         | 
| 1182 | 
            +
                ) -> torch.Tensor:
         | 
| 1183 | 
            +
                    batch_size, _, height, width = pixel_values.shape
         | 
| 1184 | 
            +
                    patch_embeddings = self.patch_embeddings(pixel_values)
         | 
| 1185 | 
            +
                    embeddings = patch_embeddings
         | 
| 1186 | 
            +
             | 
| 1187 | 
            +
                    # add the [CLS] token to the embedded patch tokens
         | 
| 1188 | 
            +
                    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         | 
| 1189 | 
            +
                    embeddings = torch.cat((cls_tokens, embeddings), dim=1)
         | 
| 1190 | 
            +
             | 
| 1191 | 
            +
                    # add positional encoding to each token
         | 
| 1192 | 
            +
                    embeddings = embeddings + self.interpolate_pos_encoding(
         | 
| 1193 | 
            +
                        embeddings, height, width
         | 
| 1194 | 
            +
                    )
         | 
| 1195 | 
            +
             | 
| 1196 | 
            +
                    return embeddings
         | 
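The `CustomPatchEmbeddings`/`CustomEmbeddings` pair above does not depend on a `Dinov2Config`, so it can be exercised in isolation. A minimal sketch, assuming the classes are importable from `spar3d.models.tokenizers.dinov2` and using illustrative sizes (224-px images, 14-px patches, 768-dim hidden size); the bicubic interpolation in `interpolate_pos_encoding` is what lets the same module accept a larger input resolution:

```python
import torch

from spar3d.models.tokenizers.dinov2 import CustomEmbeddings

# 224x224 images with 14x14 patches -> 16*16 = 256 patch tokens plus one CLS token.
emb = CustomEmbeddings(image_size=224, patch_size=14, num_channels=3, hidden_size=768)

tokens_224 = emb(torch.randn(2, 3, 224, 224))  # (2, 257, 768): position table used as-is
tokens_336 = emb(torch.randn(2, 3, 336, 336))  # (2, 577, 768): position table interpolated to 24x24
print(tokens_224.shape, tokens_336.shape)
```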
    	
        spar3d/models/tokenizers/image.py
    ADDED
    
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor

from spar3d.models.tokenizers.dinov2 import Dinov2Model
from spar3d.models.transformers.attention import Modulation
from spar3d.models.utils import BaseModule


class DINOV2SingleImageTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        width: int = 512
        height: int = 512
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)

        for p in self.model.parameters():
            p.requires_grad_(False)
        self.model.eval()

        self.model.set_gradient_checkpointing(False)

        # add modulation
        modulations = []
        for layer in self.model.encoder.layer:
            norm1_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            norm2_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
            modulations += [norm1_modulation, norm2_modulation]
        self.modulations = nn.ModuleList(modulations)

        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        model = self.model

        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features = out.last_hidden_state
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
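A hedged usage sketch for the image tokenizer above. The construction convention is an assumption (BaseModule subclasses in this codebase are configured from a config dict), the DINOv2-large weights come from the Hub, and the token count depends on the input resolution, so only the general shape flow is shown:

```python
import torch

from spar3d.models.tokenizers.image import DINOV2SingleImageTokenizer

# Assumed construction: a plain config dict; adjust if BaseModule expects a DictConfig.
tokenizer = DINOV2SingleImageTokenizer(
    {"pretrained_model_name_or_path": "facebook/dinov2-large", "modulation_cond_dim": 768}
)

images = torch.rand(1, 3, 512, 512)  # unnormalized RGB in [0, 1]; ImageNet normalization happens inside forward
cond = torch.zeros(1, 768)           # one modulation vector per image
tokens = tokenizer(images, modulation_cond=cond)
print(tokens.shape)                  # (1, Ct, Nt): DINOv2-large hidden size x number of CLS/patch tokens
```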
    	
        spar3d/models/tokenizers/point.py
    ADDED
    
from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        transformer_cfg = dict(self.cfg.copy())
        # the transformer runs at width num_attention_heads * attention_head_dim,
        # so override in_channels accordingly
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
        inputs = self.linear_in(points).permute(0, 2, 1)  # B N Ci -> B Ci N
        out = self.model(inputs).permute(0, 2, 1)  # B Ci N -> B N Ci
        out = self.linear_out(out)  # B N Ci -> B N Co
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
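A similar hedged sketch of the point tokenizer's shape contract (the config-dict construction is again an assumption, and `Transformer1D` must be importable for `configure` to run):

```python
import torch

from spar3d.models.tokenizers.point import TransformerPointTokenizer

# Assumed construction with defaults: 6 input channels (e.g. xyz plus color),
# a 16-head x 64-dim transformer, and 1024-dim output tokens.
tokenizer = TransformerPointTokenizer({})

points = torch.rand(2, 512, 6)  # (B, N, Ci): 512 points per sample
tokens = tokenizer(points)      # linear_in -> Transformer1D -> linear_out
print(tokens.shape)             # (2, 512, 1024)
```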
    	
        spar3d/models/tokenizers/triplane.py
    ADDED
    
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from spar3d.models.utils import BaseModule


class TriplaneLearnablePositionalEmbedding(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 96
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        return rearrange(
            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
        )

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
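A hedged round-trip sketch for the triplane embedding above (config-dict construction assumed); `detokenize` simply inverts the flattening that `forward` applies to the three learned planes:

```python
from spar3d.models.tokenizers.triplane import TriplaneLearnablePositionalEmbedding

# Assumed construction with defaults: plane_size=96, num_channels=1024.
triplane = TriplaneLearnablePositionalEmbedding({})

tokens = triplane(batch_size=2)       # (2, 1024, 3 * 96 * 96): three planes flattened along the token axis
planes = triplane.detokenize(tokens)  # (2, 3, 1024, 96, 96): back to per-plane feature maps
print(tokens.shape, planes.shape)
```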
