Commit c7f097c: initial commit

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +27 -0
- .gitignore +46 -0
- PIFu/.gitignore +1 -0
- PIFu/LICENSE.txt +48 -0
- PIFu/README.md +167 -0
- PIFu/apps/__init__.py +0 -0
- PIFu/apps/crop_img.py +75 -0
- PIFu/apps/eval.py +123 -0
- PIFu/apps/eval_spaces.py +130 -0
- PIFu/apps/prt_util.py +142 -0
- PIFu/apps/render_data.py +290 -0
- PIFu/apps/train_color.py +191 -0
- PIFu/apps/train_shape.py +183 -0
- PIFu/env_sh.npy +0 -0
- PIFu/environment.yml +19 -0
- PIFu/inputs/.gitignore +2 -0
- PIFu/lib/__init__.py +0 -0
- PIFu/lib/colab_util.py +114 -0
- PIFu/lib/data/BaseDataset.py +46 -0
- PIFu/lib/data/EvalDataset.py +166 -0
- PIFu/lib/data/TrainDataset.py +390 -0
- PIFu/lib/data/__init__.py +2 -0
- PIFu/lib/ext_transform.py +78 -0
- PIFu/lib/geometry.py +55 -0
- PIFu/lib/mesh_util.py +91 -0
- PIFu/lib/model/BasePIFuNet.py +76 -0
- PIFu/lib/model/ConvFilters.py +112 -0
- PIFu/lib/model/ConvPIFuNet.py +99 -0
- PIFu/lib/model/DepthNormalizer.py +18 -0
- PIFu/lib/model/HGFilters.py +146 -0
- PIFu/lib/model/HGPIFuNet.py +142 -0
- PIFu/lib/model/ResBlkPIFuNet.py +201 -0
- PIFu/lib/model/SurfaceClassifier.py +71 -0
- PIFu/lib/model/VhullPIFuNet.py +70 -0
- PIFu/lib/model/__init__.py +5 -0
- PIFu/lib/net_util.py +396 -0
- PIFu/lib/options.py +157 -0
- PIFu/lib/renderer/__init__.py +0 -0
- PIFu/lib/renderer/camera.py +207 -0
- PIFu/lib/renderer/gl/__init__.py +0 -0
- PIFu/lib/renderer/gl/cam_render.py +48 -0
- PIFu/lib/renderer/gl/data/prt.fs +153 -0
- PIFu/lib/renderer/gl/data/prt.vs +167 -0
- PIFu/lib/renderer/gl/data/prt_uv.fs +141 -0
- PIFu/lib/renderer/gl/data/prt_uv.vs +168 -0
- PIFu/lib/renderer/gl/data/quad.fs +11 -0
- PIFu/lib/renderer/gl/data/quad.vs +11 -0
- PIFu/lib/renderer/gl/framework.py +90 -0
- PIFu/lib/renderer/gl/glcontext.py +142 -0
- PIFu/lib/renderer/gl/init_gl.py +24 -0
    	
.gitattributes ADDED

@@ -0,0 +1,27 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
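
These attributes route large binary artifacts (archives, serialized models, checkpoints) through Git LFS, so the repository stores lightweight pointers instead of the files themselves. As a rough sketch of how such glob patterns classify a path (my illustration, using Python's `fnmatch`; real gitattributes matching has richer semantics such as `**` and per-directory attribute files):

```python
# Illustrative sketch only: approximates how the LFS patterns above
# classify a path. Real .gitattributes matching is richer (e.g. '**',
# per-directory attribute files), so treat this as an approximation.
from fnmatch import fnmatch

LFS_PATTERNS = ["*.7z", "*.bin", "*.h5", "*.onnx", "*.pt", "*.pth", "*.zip"]

def is_lfs_tracked(path: str) -> bool:
    # Patterns without a slash match against the basename.
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch(name, pat) for pat in LFS_PATTERNS)

print(is_lfs_tracked("PIFu/checkpoints/net_G.pth"))  # True
print(is_lfs_tracked("PIFu/apps/eval.py"))           # False
```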
    	
.gitignore ADDED

@@ -0,0 +1,46 @@
+# Python build
+.eggs/
+gradio.egg-info/*
+!gradio.egg-info/requires.txt
+!gradio.egg-info/PKG-INFO
+dist/
+*.pyc
+__pycache__/
+*.py[cod]
+*$py.class
+build/
+
+# JS build
+gradio/templates/frontend
+# Secrets
+.env
+
+# Gradio run artifacts
+*.db
+*.sqlite3
+gradio/launches.json
+flagged/
+gradio_cached_examples/
+
+# Tests
+.coverage
+coverage.xml
+test.txt
+
+# Demos
+demo/tmp.zip
+demo/files/*.avi
+demo/files/*.mp4
+
+# Etc
+.idea/*
+.DS_Store
+*.bak
+workspace.code-workspace
+*.h5
+.vscode/
+
+# log files
+.pnpm-debug.log
+venv/
+*.db-journal
    	
PIFu/.gitignore ADDED

@@ -0,0 +1 @@
+checkpoints/*
    	
PIFu/LICENSE.txt ADDED

@@ -0,0 +1,48 @@
+MIT License
+
+Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+anyabagomo
+
+-------------------- LICENSE FOR ResBlk Image Encoder -----------------------
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    	
PIFu/README.md ADDED

@@ -0,0 +1,167 @@
+# PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization
+
+[arXiv paper](https://arxiv.org/abs/1905.05172) [Google Colab demo](https://colab.research.google.com/drive/1GFSsqP2BWz4gtq0e-nki00ZHSirXwFyY)
+
+News:
+* \[2020/05/04\] Added an EGL rendering option for training data generation. Now you can create your own training data on headless machines!
+* \[2020/04/13\] A demo on Google Colab (incl. visualization) is available. Special thanks to [@nanopoteto](https://github.com/nanopoteto)!!!
+* \[2020/02/26\] The license is updated to the MIT license! Enjoy!
+
+This repository contains a PyTorch implementation of "[PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization](https://arxiv.org/abs/1905.05172)".
+
+[Project Page](https://shunsukesaito.github.io/PIFu/)
+(teaser image)
+
+If you find the code useful in your research, please consider citing the paper.
+
+```
+@InProceedings{saito2019pifu,
+author = {Saito, Shunsuke and Huang, Zeng and Natsume, Ryota and Morishima, Shigeo and Kanazawa, Angjoo and Li, Hao},
+title = {PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization},
+booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
+month = {October},
+year = {2019}
+}
+```
+
+
+This codebase provides:
+- test code
+- training code
+- data generation code
+
+## Requirements
+- Python 3
+- [PyTorch](https://pytorch.org/) (tested on 1.4.0)
+- json
+- PIL
+- skimage
+- tqdm
+- numpy
+- cv2
+
+For training and data generation:
+- [trimesh](https://trimsh.org/) with [pyembree](https://github.com/scopatz/pyembree)
+- [pyexr](https://github.com/tvogels/pyexr)
+- PyOpenGL
+- freeglut (`sudo apt-get install freeglut3-dev` on Ubuntu)
+- (optional) EGL-related packages for rendering on headless machines (`apt install libgl1-mesa-dri libegl1-mesa libgbm1` on Ubuntu)
+
+Warning: outdated NVIDIA drivers may cause errors with EGL. If you want to try the EGL version, please update your NVIDIA driver to the latest!!
+
+## Windows demo installation instructions
+
+- Install [miniconda](https://docs.conda.io/en/latest/miniconda.html)
+- Add `conda` to PATH
+- Install [git bash](https://git-scm.com/downloads)
+- Launch `Git\bin\bash.exe`
+- Run `eval "$(conda shell.bash hook)"` then `conda activate my_env` because of [this](https://github.com/conda/conda-build/issues/3371)
+- Automatic: `env create -f environment.yml` (see [this](https://github.com/conda/conda/issues/3417))
+- OR manually set up an [environment](https://towardsdatascience.com/a-guide-to-conda-environments-bc6180fc533):
+    - `conda create --name pifu python` where `pifu` is the name of your environment
+    - `conda activate pifu`
+    - `conda install pytorch torchvision cudatoolkit=10.1 -c pytorch`
+    - `conda install pillow`
+    - `conda install scikit-image`
+    - `conda install tqdm`
+    - `conda install -c menpo opencv`
+- Download [wget.exe](https://eternallybored.org/misc/wget/)
+- Place it in `Git\mingw64\bin`
+- Run `sh ./scripts/download_trained_model.sh`
+- Remove the background from your image ([this](https://www.remove.bg/), for example)
+- Create a black-and-white mask `.png`
+- Replace the originals in `sample_images/`
+- Try it out: `sh ./scripts/test.sh`
+- Download [Meshlab](http://www.meshlab.net/) because of [this](https://github.com/shunsukesaito/PIFu/issues/1)
+- Open the `.obj` file in Meshlab
+
+
+## Demo
+Warning: the released model is trained mostly on upright standing scans with weak perspective projection and a pitch angle of 0 degrees. Reconstruction quality may degrade for images that deviate strongly from the training data.
+1. Run the following script to download the pretrained models and copy them under `./PIFu/checkpoints/`.
+```
+sh ./scripts/download_trained_model.sh
+```
+
+2. Run the following script. It creates a textured `.obj` file under `./PIFu/eval_results/`. You may need to use `./apps/crop_img.py` to roughly align an input image and the corresponding mask to the training data for better performance. For background removal, you can use any off-the-shelf tool such as [removebg](https://www.remove.bg/).
+```
+sh ./scripts/test.sh
+```
+
+## Demo on Google Colab
+If you do not have a setup to run PIFu, we offer a Google Colab version, allowing you to run PIFu in the cloud, free of charge. Try our Colab demo using the following notebook:
+[Open in Google Colab](https://colab.research.google.com/drive/1GFSsqP2BWz4gtq0e-nki00ZHSirXwFyY)
+
+## Data Generation (Linux Only)
+While we are unable to release the full training data due to restrictions on the commercial scans, we provide rendering code using the free models from [RenderPeople](https://renderpeople.com/free-3d-people/).
+This tutorial uses the `rp_dennis_posed_004` model. Please download the model from [this link](https://renderpeople.com/sample/free/rp_dennis_posed_004_OBJ.zip) and unzip the content into a folder named `rp_dennis_posed_004_OBJ`. The same process can be applied to other RenderPeople data.
+
+Warning: the following code becomes extremely slow without [pyembree](https://github.com/scopatz/pyembree). Please make sure you install pyembree.
+
+1. Run the following script to compute spherical harmonics coefficients for [precomputed radiance transfer (PRT)](https://sites.fas.harvard.edu/~cs278/papers/prt.pdf). In a nutshell, PRT is used to account for accurate light transport, including ambient occlusion, without compromising online rendering time, which significantly improves photorealism compared with [common spherical harmonics rendering using surface normals](https://cseweb.ucsd.edu/~ravir/papers/envmap/envmap.pdf). This process has to be done once per obj file.
+```
+python -m apps.prt_util -i {path_to_rp_dennis_posed_004_OBJ}
+```
+
+2. Run the following script. Under the specified data path, the code creates folders named `GEO`, `RENDER`, `MASK`, `PARAM`, `UV_RENDER`, `UV_MASK`, `UV_NORMAL`, and `UV_POS`. Note that you may need to list validation subjects to exclude from training in `{path_to_training_data}/val.txt` (this tutorial has only one subject, so you can leave it empty). If you wish to render images on headless servers equipped with an NVIDIA GPU, add `-e` to enable EGL rendering (a folder-layout check is sketched after this file's diff).
+```
+python -m apps.render_data -i {path_to_rp_dennis_posed_004_OBJ} -o {path_to_training_data} [-e]
+```
+
+## Training (Linux Only)
+
+Warning: the following code becomes extremely slow without [pyembree](https://github.com/scopatz/pyembree). Please make sure you install pyembree.
+
+1. Run the following script to train the shape module. The intermediate results and checkpoints are saved under `./results` and `./checkpoints` respectively. You can add the `--batch_size` and `--num_sample_inout` flags to adjust the batch size and the number of sampled points based on available GPU memory.
+```
+python -m apps.train_shape --dataroot {path_to_training_data} --random_flip --random_scale --random_trans
+```
+
+2. Run the following script to train the color module.
+```
+python -m apps.train_color --dataroot {path_to_training_data} --num_sample_inout 0 --num_sample_color 5000 --sigma 0.1 --random_flip --random_scale --random_trans
+```
+
+## Related Research
+**[Monocular Real-Time Volumetric Performance Capture (ECCV 2020)](https://project-splinter.github.io/)**  
+*Ruilong Li\*, Yuliang Xiu\*, Shunsuke Saito, Zeng Huang, Kyle Olszewski, Hao Li*
+
+The first real-time PIFu, achieved by accelerating reconstruction and rendering!!
+
+**[PIFuHD: Multi-Level Pixel-Aligned Implicit Function for High-Resolution 3D Human Digitization (CVPR 2020)](https://shunsukesaito.github.io/PIFuHD/)**  
+*Shunsuke Saito, Tomas Simon, Jason Saragih, Hanbyul Joo*
+
+We further improve the quality of reconstruction by leveraging a multi-level approach!
+
+**[ARCH: Animatable Reconstruction of Clothed Humans (CVPR 2020)](https://arxiv.org/pdf/2004.04572.pdf)**  
+*Zeng Huang, Yuanlu Xu, Christoph Lassner, Hao Li, Tony Tung*
+
+Learning PIFu in canonical space for animatable avatar generation!
+
+**[Robust 3D Self-portraits in Seconds (CVPR 2020)](http://www.liuyebin.com/portrait/portrait.html)**  
+*Zhe Li, Tao Yu, Chuanyu Pan, Zerong Zheng, Yebin Liu*
+
+They extend PIFu to RGB-D and introduce "PIFusion", which uses PIFu reconstruction for non-rigid fusion.
+
+**[Learning to Infer Implicit Surfaces without 3D Supervision (NeurIPS 2019)](http://papers.nips.cc/paper/9039-learning-to-infer-implicit-surfaces-without-3d-supervision.pdf)**  
+*Shichen Liu, Shunsuke Saito, Weikai Chen, Hao Li*
+
+We answer the question: how can we learn an implicit function if we don't have 3D ground truth?
+
+**[SiCloPe: Silhouette-Based Clothed People (CVPR 2019, best paper finalist)](https://arxiv.org/pdf/1901.00049.pdf)**  
+*Ryota Natsume\*, Shunsuke Saito\*, Zeng Huang, Weikai Chen, Chongyang Ma, Hao Li, Shigeo Morishima*
+
+Our first attempt to reconstruct a 3D clothed human body with texture from a single image!
+
+**[Deep Volumetric Video from Very Sparse Multi-view Performance Capture (ECCV 2018)](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zeng_Huang_Deep_Volumetric_Video_ECCV_2018_paper.pdf)**  
+*Zeng Huang, Tianye Li, Weikai Chen, Yajie Zhao, Jun Xing, Chloe LeGendre, Linjie Luo, Chongyang Ma, Hao Li*
+
+Implicit surface learning for sparse-view human performance capture!
+
+------
+
+
+
+For commercial queries, please contact:
+
+Hao Li: [email protected] ccto: [email protected] Baker!!
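
Step 2 of the Data Generation section above lists the folders that `apps.render_data` creates. A minimal sketch (my addition, not part of the commit) that checks a data root for that layout; only the folder names are taken from the README:

```python
# Hypothetical helper, not part of this commit: report which of the
# folders the README says apps.render_data creates are missing.
import os

EXPECTED = ["GEO", "RENDER", "MASK", "PARAM",
            "UV_RENDER", "UV_MASK", "UV_NORMAL", "UV_POS"]

def check_training_data(root: str) -> list:
    """Return the expected subfolders missing under `root`."""
    return [d for d in EXPECTED if not os.path.isdir(os.path.join(root, d))]

missing = check_training_data("/path/to/training_data")  # placeholder path
print("missing folders:", missing or "none")
```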
    	
PIFu/apps/__init__.py ADDED

File without changes
    	
PIFu/apps/crop_img.py ADDED

@@ -0,0 +1,75 @@
+import os
+import cv2
+import numpy as np
+
+from pathlib import Path
+import argparse
+
+def get_bbox(msk):
+    rows = np.any(msk, axis=1)
+    cols = np.any(msk, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+
+    return rmin, rmax, cmin, cmax
+
+def process_img(img, msk, bbox=None):
+    if bbox is None:
+        bbox = get_bbox(msk > 100)
+    cx = (bbox[3] + bbox[2]) // 2
+    cy = (bbox[1] + bbox[0]) // 2
+
+    w = img.shape[1]
+    h = img.shape[0]
+    height = int(1.138 * (bbox[1] - bbox[0]))
+    hh = height // 2
+
+    # crop, padding with black where the window leaves the image
+    dw = min(cx, w - cx, hh)
+    if cy - hh < 0:
+        img = cv2.copyMakeBorder(img, hh - cy, 0, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+        msk = cv2.copyMakeBorder(msk, hh - cy, 0, 0, 0, cv2.BORDER_CONSTANT, value=0)
+        cy = hh
+    if cy + hh > h:
+        img = cv2.copyMakeBorder(img, 0, cy + hh - h, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+        msk = cv2.copyMakeBorder(msk, 0, cy + hh - h, 0, 0, cv2.BORDER_CONSTANT, value=0)
+    img = img[cy - hh:(cy + hh), cx - dw:cx + dw, :]
+    msk = msk[cy - hh:(cy + hh), cx - dw:cx + dw]
+    dw = img.shape[0] - img.shape[1]
+    if dw != 0:
+        img = cv2.copyMakeBorder(img, 0, 0, dw // 2, dw // 2, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+        msk = cv2.copyMakeBorder(msk, 0, 0, dw // 2, dw // 2, cv2.BORDER_CONSTANT, value=0)
+    img = cv2.resize(img, (512, 512))
+    msk = cv2.resize(msk, (512, 512))
+
+    kernel = np.ones((3, 3), np.uint8)
+    msk = cv2.erode((255 * (msk > 100)).astype(np.uint8), kernel, iterations=1)
+
+    return img, msk
+
+def main():
+    '''
+    Given a foreground mask, this script crops and resizes an input image and mask for processing.
+    '''
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input_image', type=str, help='if the image has an alpha channel, it will be used as the mask')
+    parser.add_argument('-m', '--input_mask', type=str)
+    parser.add_argument('-o', '--out_path', type=str, default='./sample_images')
+    args = parser.parse_args()
+
+    img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
+    if img.shape[2] == 4:
+        msk = img[:, :, 3:]
+        img = img[:, :, :3]
+    else:
+        msk = cv2.imread(args.input_mask, cv2.IMREAD_GRAYSCALE)
+
+    img_new, msk_new = process_img(img, msk)
+
+    img_name = Path(args.input_image).stem
+
+    cv2.imwrite(os.path.join(args.out_path, img_name + '.png'), img_new)
+    cv2.imwrite(os.path.join(args.out_path, img_name + '_mask.png'), msk_new)
+
+if __name__ == "__main__":
+    main()
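
A usage sketch for `process_img` (my addition, not part of the commit): the function also works on in-memory arrays, so a synthetic image/mask pair can stand in for files on disk.

```python
# Usage sketch, not part of this commit: run process_img on a synthetic
# image/mask pair instead of reading files from disk.
import numpy as np
from apps.crop_img import process_img  # assumes PIFu/ is the working directory

img = np.zeros((480, 640, 3), dtype=np.uint8)   # blank 640x480 image
msk = np.zeros((480, 640), dtype=np.uint8)
msk[100:400, 250:390] = 255                     # rectangular foreground
img[msk > 0] = 200                              # paint the subject gray

img_new, msk_new = process_img(img, msk)
print(img_new.shape, msk_new.shape)             # (512, 512, 3) (512, 512)
```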
    	
PIFu/apps/eval.py ADDED

@@ -0,0 +1,123 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+import time
+import json
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from lib.options import BaseOptions
+from lib.mesh_util import *
+from lib.sample_util import *
+from lib.train_util import *
+from lib.model import *
+
+from PIL import Image
+import torchvision.transforms as transforms
+import glob
+import tqdm
+
+# get options
+opt = BaseOptions().parse()
+
+class Evaluator:
+    def __init__(self, opt, projection_mode='orthogonal'):
+        self.opt = opt
+        self.load_size = self.opt.loadSize
+        self.to_tensor = transforms.Compose([
+            transforms.Resize(self.load_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ])
+        # set cuda
+        cuda = torch.device('cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
+
+        # create net
+        netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
+        print('Using Network: ', netG.name)
+
+        if opt.load_netG_checkpoint_path:
+            netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
+
+        if opt.load_netC_checkpoint_path is not None:
+            print('loading for net C ...', opt.load_netC_checkpoint_path)
+            netC = ResBlkPIFuNet(opt).to(device=cuda)
+            netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
+        else:
+            netC = None
+
+        os.makedirs(opt.results_path, exist_ok=True)
+        os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
+
+        opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
+        with open(opt_log, 'w') as outfile:
+            outfile.write(json.dumps(vars(opt), indent=2))
+
+        self.cuda = cuda
+        self.netG = netG
+        self.netC = netC
+
+    def load_image(self, image_path, mask_path):
+        # Name
+        img_name = os.path.splitext(os.path.basename(image_path))[0]
+        # Calib
+        B_MIN = np.array([-1, -1, -1])
+        B_MAX = np.array([1, 1, 1])
+        projection_matrix = np.identity(4)
+        projection_matrix[1, 1] = -1
+        calib = torch.Tensor(projection_matrix).float()
+        # Mask
+        mask = Image.open(mask_path).convert('L')
+        mask = transforms.Resize(self.load_size)(mask)
+        mask = transforms.ToTensor()(mask).float()
+        # Image
+        image = Image.open(image_path).convert('RGB')
+        image = self.to_tensor(image)
+        image = mask.expand_as(image) * image
+        return {
+            'name': img_name,
+            'img': image.unsqueeze(0),
+            'calib': calib.unsqueeze(0),
+            'mask': mask.unsqueeze(0),
+            'b_min': B_MIN,
+            'b_max': B_MAX,
+        }
+
+    def eval(self, data, use_octree=False):
+        '''
+        Evaluate a data point.
+        :param data: a dict containing at least the ['name'], ['img'], ['calib'], ['b_min'], and ['b_max'] entries.
+        '''
+        opt = self.opt
+        with torch.no_grad():
+            self.netG.eval()
+            if self.netC:
+                self.netC.eval()
+            save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
+            if self.netC:
+                gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
+            else:
+                gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)
+
+
+if __name__ == '__main__':
+    evaluator = Evaluator(opt)
+
+    test_images = glob.glob(os.path.join(opt.test_folder_path, '*'))
+    test_images = [f for f in test_images if ('png' in f or 'jpg' in f) and 'mask' not in f]
+    test_masks = [f[:-4] + '_mask.png' for f in test_images]
+
+    print("num:", len(test_masks))
+
+    for image_path, mask_path in tqdm.tqdm(zip(test_images, test_masks)):
+        try:
+            print(image_path, mask_path)
+            data = evaluator.load_image(image_path, mask_path)
+            evaluator.eval(data, True)
+        except Exception as e:
+            print("error:", e.args)
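
Note the naming convention in the test loop above: each image `foo.png` (or `.jpg`) must sit next to a mask named `foo_mask.png`. A small hypothetical helper (my addition, not part of the commit) for staging inputs that satisfy this convention:

```python
# Hypothetical helper, not part of this commit: copy a PNG image and its
# mask into a folder using the foo.png / foo_mask.png layout that
# apps/eval.py expects.
import os
import shutil

def stage_pair(image_src: str, mask_src: str, folder: str) -> None:
    os.makedirs(folder, exist_ok=True)
    stem = os.path.splitext(os.path.basename(image_src))[0]
    shutil.copy(image_src, os.path.join(folder, stem + ".png"))
    shutil.copy(mask_src, os.path.join(folder, stem + "_mask.png"))

stage_pair("person.png", "person_bw_mask.png", "./sample_images")  # placeholder names
```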
    	
        PIFu/apps/eval_spaces.py
    ADDED
    
    | @@ -0,0 +1,130 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
         | 
| 5 | 
            +
            ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            import json
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from torch.utils.data import DataLoader
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from lib.options import BaseOptions
         | 
| 14 | 
            +
            from lib.mesh_util import *
         | 
| 15 | 
            +
            from lib.sample_util import *
         | 
| 16 | 
            +
            from lib.train_util import *
         | 
| 17 | 
            +
            from lib.model import *
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from PIL import Image
         | 
| 20 | 
            +
            import torchvision.transforms as transforms
         | 
| 21 | 
            +
            import glob
         | 
| 22 | 
            +
            import tqdm
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            import trimesh
         | 
| 25 | 
            +
            # get options
         | 
| 26 | 
            +
            opt = BaseOptions().parse()
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            class Evaluator:
         | 
| 29 | 
            +
                def __init__(self, opt, projection_mode='orthogonal'):
         | 
| 30 | 
            +
                    self.opt = opt
         | 
| 31 | 
            +
                    self.load_size = self.opt.loadSize
         | 
| 32 | 
            +
                    self.to_tensor = transforms.Compose([
         | 
| 33 | 
            +
                        transforms.Resize(self.load_size),
         | 
| 34 | 
            +
                        transforms.ToTensor(),
         | 
| 35 | 
            +
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         | 
| 36 | 
            +
                    ])
         | 
| 37 | 
            +
                    # set cuda
         | 
| 38 | 
            +
                    cuda = torch.device('cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    # create net
         | 
| 41 | 
            +
                    netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
         | 
| 42 | 
            +
                    print('Using Network: ', netG.name)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    if opt.load_netG_checkpoint_path:
         | 
| 45 | 
            +
                        netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    if opt.load_netC_checkpoint_path is not None:
         | 
| 48 | 
            +
                        print('loading for net C ...', opt.load_netC_checkpoint_path)
         | 
| 49 | 
            +
                        netC = ResBlkPIFuNet(opt).to(device=cuda)
         | 
| 50 | 
            +
                        netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
         | 
| 51 | 
            +
                    else:
         | 
| 52 | 
            +
                        netC = None
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    os.makedirs(opt.results_path, exist_ok=True)
         | 
| 55 | 
            +
                    os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
         | 
| 58 | 
            +
        with open(opt_log, 'w') as outfile:
            outfile.write(json.dumps(vars(opt), indent=2))

        self.cuda = cuda
        self.netG = netG
        self.netC = netC

    def load_image(self, image_path, mask_path):
        # Name
        img_name = os.path.splitext(os.path.basename(image_path))[0]
        # Calib: orthographic projection with a flipped y axis
        B_MIN = np.array([-1, -1, -1])
        B_MAX = np.array([1, 1, 1])
        projection_matrix = np.identity(4)
        projection_matrix[1, 1] = -1
        calib = torch.Tensor(projection_matrix).float()
        # Mask
        mask = Image.open(mask_path).convert('L')
        mask = transforms.Resize(self.load_size)(mask)
        mask = transforms.ToTensor()(mask).float()
        # Image, with the background zeroed out by the mask
        image = Image.open(image_path).convert('RGB')
        image = self.to_tensor(image)
        image = mask.expand_as(image) * image
        return {
            'name': img_name,
            'img': image.unsqueeze(0),
            'calib': calib.unsqueeze(0),
            'mask': mask.unsqueeze(0),
            'b_min': B_MIN,
            'b_max': B_MAX,
        }

    def eval(self, data, use_octree=False):
        '''
        Evaluate a data point.
        :param data: a dict containing at least the 'name', 'img', 'calib', 'b_min' and 'b_max' entries.
        :return:
        '''
        opt = self.opt
        with torch.no_grad():
            self.netG.eval()
            if self.netC:
                self.netC.eval()
            save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
            if self.netC:
                gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
            else:
                gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)


if __name__ == '__main__':
    evaluator = Evaluator(opt)

    results_path = opt.results_path
    name = opt.name
    test_image_path = opt.img_path
    test_mask_path = test_image_path[:-4] + '_mask.png'
    test_img_name = os.path.splitext(os.path.basename(test_image_path))[0]
    print("test_image: ", test_image_path)
    print("test_mask: ", test_mask_path)

    try:
        data = evaluator.load_image(test_image_path, test_mask_path)
        evaluator.eval(data, True)
        mesh = trimesh.load(f'{results_path}/{name}/result_{test_img_name}.obj')
        # Flip the z axis so the exported glTF model faces the viewer.
        mesh.apply_transform([[1, 0, 0, 0],
                              [0, 1, 0, 0],
                              [0, 0, -1, 0],
                              [0, 0, 0, 1]])
        mesh.export(file_obj=f'{results_path}/{name}/result_{test_img_name}.glb')
    except Exception as e:
        print("error:", e.args)
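
For reference, a minimal sketch of driving the Evaluator above directly from Python rather than through opt.img_path (paths here are illustrative; opt, netG and netC are assumed to be constructed as earlier in this file):

    # Illustrative usage of the class defined above; 'opt' is a parsed options object.
    evaluator = Evaluator(opt)
    data = evaluator.load_image('inputs/person.png', 'inputs/person_mask.png')
    evaluator.eval(data, use_octree=True)  # writes result_person.obj under opt.results_path/opt.name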
    	
        PIFu/apps/prt_util.py
    ADDED
    
@@ -0,0 +1,142 @@
import os
import trimesh
import numpy as np
import math
from scipy.special import sph_harm
import argparse
from tqdm import tqdm

def factratio(N, D):
    if N >= D:
        prod = 1.0
        for i in range(D+1, N+1):
            prod *= i
        return prod
    else:
        prod = 1.0
        for i in range(N+1, D+1):
            prod *= i
        return 1.0 / prod

def KVal(M, L):
    return math.sqrt(((2 * L + 1) / (4 * math.pi)) * (factratio(L - M, L + M)))

def AssociatedLegendre(M, L, x):
    if M < 0 or M > L or np.max(np.abs(x)) > 1.0:
        return np.zeros_like(x)

    pmm = np.ones_like(x)
    if M > 0:
        somx2 = np.sqrt((1.0 + x) * (1.0 - x))
        fact = 1.0
        for i in range(1, M+1):
            pmm = -pmm * fact * somx2
            fact = fact + 2

    if L == M:
        return pmm
    else:
        pmmp1 = x * (2 * M + 1) * pmm
        if L == M+1:
            return pmmp1
        else:
            pll = np.zeros_like(x)
            for i in range(M+2, L+1):
                pll = (x * (2 * i - 1) * pmmp1 - (i + M - 1) * pmm) / (i - M)
                pmm = pmmp1
                pmmp1 = pll
            return pll

def SphericalHarmonic(M, L, theta, phi):
    if M > 0:
        return math.sqrt(2.0) * KVal(M, L) * np.cos(M * phi) * AssociatedLegendre(M, L, np.cos(theta))
    elif M < 0:
        return math.sqrt(2.0) * KVal(-M, L) * np.sin(-M * phi) * AssociatedLegendre(-M, L, np.cos(theta))
    else:
        return KVal(0, L) * AssociatedLegendre(0, L, np.cos(theta))

def save_obj(mesh_path, verts):
    with open(mesh_path, 'w') as file:
        for v in verts:
            file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))

def sampleSphericalDirections(n):
    xv = np.random.rand(n, n)
    yv = np.random.rand(n, n)
    theta = np.arccos(1 - 2 * xv)
    phi = 2.0 * math.pi * yv

    phi = phi.reshape(-1)
    theta = theta.reshape(-1)

    vx = -np.sin(theta) * np.cos(phi)
    vy = -np.sin(theta) * np.sin(phi)
    vz = np.cos(theta)
    return np.stack([vx, vy, vz], 1), phi, theta

def getSHCoeffs(order, phi, theta):
    shs = []
    for n in range(0, order+1):
        for m in range(-n, n+1):
            s = SphericalHarmonic(m, n, theta, phi)
            shs.append(s)

    return np.stack(shs, 1)

def computePRT(mesh_path, n, order):
    mesh = trimesh.load(mesh_path, process=False)
    vectors_orig, phi, theta = sampleSphericalDirections(n)
    SH_orig = getSHCoeffs(order, phi, theta)

    w = 4.0 * math.pi / (n*n)

    origins = mesh.vertices
    normals = mesh.vertex_normals
    n_v = origins.shape[0]

    origins = np.repeat(origins[:, None], n, axis=1).reshape(-1, 3)
    normals = np.repeat(normals[:, None], n, axis=1).reshape(-1, 3)
    PRT_all = None
    for i in tqdm(range(n)):
        SH = np.repeat(SH_orig[None, (i*n):((i+1)*n)], n_v, axis=0).reshape(-1, SH_orig.shape[1])
        vectors = np.repeat(vectors_orig[None, (i*n):((i+1)*n)], n_v, axis=0).reshape(-1, 3)

        dots = (vectors * normals).sum(1)
        front = (dots > 0.0)

        delta = 1e-3 * min(mesh.bounding_box.extents)
        hits = mesh.ray.intersects_any(origins + delta * normals, vectors)
        nohits = np.logical_and(front, np.logical_not(hits))

        # np.float was removed in NumPy 1.24; the builtin float is equivalent here.
        PRT = (nohits.astype(float) * dots)[:, None] * SH

        if PRT_all is not None:
            PRT_all += (PRT.reshape(-1, n, SH.shape[1]).sum(1))
        else:
            PRT_all = (PRT.reshape(-1, n, SH.shape[1]).sum(1))

    PRT = w * PRT_all

    # NOTE: trimesh sometimes breaks the original vertex order, but the topology will not change.
    # When loading PRT in another program, use the triangle list from trimesh.
    return PRT, mesh.faces

def testPRT(dir_path, n=40):
    if dir_path[-1] == '/':
        dir_path = dir_path[:-1]
    sub_name = dir_path.split('/')[-1][:-4]
    obj_path = os.path.join(dir_path, sub_name + '_100k.obj')
    os.makedirs(os.path.join(dir_path, 'bounce'), exist_ok=True)

    PRT, F = computePRT(obj_path, n, 2)
    np.savetxt(os.path.join(dir_path, 'bounce', 'bounce0.txt'), PRT, fmt='%.8f')
    np.save(os.path.join(dir_path, 'bounce', 'face.npy'), F)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='/home/shunsuke/Downloads/rp_dennis_posed_004_OBJ')
    parser.add_argument('-n', '--n_sample', type=int, default=40, help='square root of the number of sampled directions; higher is more accurate but slower')
    args = parser.parse_args()

    testPRT(args.input, args.n_sample)  # pass -n through so the flag takes effect
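
A quick sanity check for the sampling and SH helpers above (a sketch; it assumes prt_util.py is importable as apps.prt_util from the repository root, with numpy and scipy installed):

    import numpy as np
    from apps.prt_util import sampleSphericalDirections, getSHCoeffs

    dirs, phi, theta = sampleSphericalDirections(4)  # n=4 -> 4*4 = 16 random directions
    sh = getSHCoeffs(2, phi, theta)                  # order 2 -> (2+1)^2 = 9 basis values
    print(dirs.shape, sh.shape)                      # (16, 3) (16, 9)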
    	
        PIFu/apps/render_data.py
    ADDED
    
@@ -0,0 +1,290 @@
#from data.config import raw_dataset, render_dataset, archive_dataset, model_list, zip_path

from lib.renderer.camera import Camera
import numpy as np
from lib.renderer.mesh import load_obj_mesh, compute_tangent, compute_normal, load_obj_mesh_mtl
import os
import cv2
import time
import math
import random
import pyexr
import argparse
from tqdm import tqdm


def make_rotate(rx, ry, rz):
    sinX = np.sin(rx)
    sinY = np.sin(ry)
    sinZ = np.sin(rz)

    cosX = np.cos(rx)
    cosY = np.cos(ry)
    cosZ = np.cos(rz)

    Rx = np.zeros((3,3))
    Rx[0, 0] = 1.0
    Rx[1, 1] = cosX
    Rx[1, 2] = -sinX
    Rx[2, 1] = sinX
    Rx[2, 2] = cosX

    Ry = np.zeros((3,3))
    Ry[0, 0] = cosY
    Ry[0, 2] = sinY
    Ry[1, 1] = 1.0
    Ry[2, 0] = -sinY
    Ry[2, 2] = cosY

    Rz = np.zeros((3,3))
    Rz[0, 0] = cosZ
    Rz[0, 1] = -sinZ
    Rz[1, 0] = sinZ
    Rz[1, 1] = cosZ
    Rz[2, 2] = 1.0

    R = np.matmul(np.matmul(Rz, Ry), Rx)
    return R

def rotateSH(SH, R):
    SHn = SH.copy()  # copy so the caller's SH array is not mutated in place

    # 1st order
    SHn[1] = R[1,1]*SH[1] - R[1,2]*SH[2] + R[1,0]*SH[3]
    SHn[2] = -R[2,1]*SH[1] + R[2,2]*SH[2] - R[2,0]*SH[3]
    SHn[3] = R[0,1]*SH[1] - R[0,2]*SH[2] + R[0,0]*SH[3]

    # 2nd order
    SHn[4:,0] = rotateBand2(SH[4:,0], R)
    SHn[4:,1] = rotateBand2(SH[4:,1], R)
    SHn[4:,2] = rotateBand2(SH[4:,2], R)

    return SHn

def rotateBand2(x, R):
    s_c3 = 0.94617469575
    s_c4 = -0.31539156525
    s_c5 = 0.54627421529

    s_c_scale = 1.0/0.91529123286551084
    s_c_scale_inv = 0.91529123286551084

    s_rc2 = 1.5853309190550713*s_c_scale
    s_c4_div_c3 = s_c4/s_c3
    s_c4_div_c3_x2 = (s_c4/s_c3)*2.0

    s_scale_dst2 = s_c3 * s_c_scale_inv
    s_scale_dst4 = s_c5 * s_c_scale_inv

    sh0 =  x[3] + x[4] + x[4] - x[1]
    sh1 =  x[0] + s_rc2*x[2] + x[3] + x[4]
    sh2 =  x[0]
    sh3 = -x[3]
    sh4 = -x[1]

    r2x = R[0][0] + R[0][1]
    r2y = R[1][0] + R[1][1]
    r2z = R[2][0] + R[2][1]

    r3x = R[0][0] + R[0][2]
    r3y = R[1][0] + R[1][2]
    r3z = R[2][0] + R[2][2]

    r4x = R[0][1] + R[0][2]
    r4y = R[1][1] + R[1][2]
    r4z = R[2][1] + R[2][2]

    sh0_x = sh0 * R[0][0]
    sh0_y = sh0 * R[1][0]
    d0 = sh0_x * R[1][0]
    d1 = sh0_y * R[2][0]
    d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3)
    d3 = sh0_x * R[2][0]
    d4 = sh0_x * R[0][0] - sh0_y * R[1][0]

    sh1_x = sh1 * R[0][2]
    sh1_y = sh1 * R[1][2]
    d0 += sh1_x * R[1][2]
    d1 += sh1_y * R[2][2]
    d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3)
    d3 += sh1_x * R[2][2]
    d4 += sh1_x * R[0][2] - sh1_y * R[1][2]

    sh2_x = sh2 * r2x
    sh2_y = sh2 * r2y
    d0 += sh2_x * r2y
    d1 += sh2_y * r2z
    d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2)
    d3 += sh2_x * r2z
    d4 += sh2_x * r2x - sh2_y * r2y

    sh3_x = sh3 * r3x
    sh3_y = sh3 * r3y
    d0 += sh3_x * r3y
    d1 += sh3_y * r3z
    d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2)
    d3 += sh3_x * r3z
    d4 += sh3_x * r3x - sh3_y * r3y

    sh4_x = sh4 * r4x
    sh4_y = sh4 * r4y
    d0 += sh4_x * r4y
    d1 += sh4_y * r4z
    d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2)
    d3 += sh4_x * r4z
    d4 += sh4_x * r4x - sh4_y * r4y

    dst = x.copy()  # copy for the same reason as in rotateSH
    dst[0] = d0
    dst[1] = -d1
    dst[2] = d2 * s_scale_dst2
    dst[3] = -d3
    dst[4] = d4 * s_scale_dst4

    return dst
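
As a rough correctness check on the band-2 rotation above (a sketch; importing apps.render_data pulls in pyexr, cv2 and the GL renderer modules, so those dependencies must be installed): rotating an order-2 SH vector by the identity should leave it unchanged.

    import numpy as np
    from apps.render_data import make_rotate, rotateSH

    sh = np.random.rand(9, 3)                 # 9 SH coefficients per RGB channel
    out = rotateSH(sh, make_rotate(0, 0, 0))  # identity rotation
    print(np.allclose(out, sh))               # expected: True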

def render_prt_ortho(out_path, folder_name, subject_name, shs, rndr, rndr_uv, im_size, angl_step=4, n_light=1, pitch=[0]):
    cam = Camera(width=im_size, height=im_size)
    cam.ortho_ratio = 0.4 * (512 / im_size)
    cam.near = -100
    cam.far = 100
    cam.sanity_check()

    # set paths for obj, prt
    mesh_file = os.path.join(folder_name, subject_name + '_100k.obj')
    if not os.path.exists(mesh_file):
        print('ERROR: obj file does not exist!!', mesh_file)
        return
    prt_file = os.path.join(folder_name, 'bounce', 'bounce0.txt')
    if not os.path.exists(prt_file):
        print('ERROR: prt file does not exist!!!', prt_file)
        return
    face_prt_file = os.path.join(folder_name, 'bounce', 'face.npy')
    if not os.path.exists(face_prt_file):
        print('ERROR: face prt file does not exist!!!', face_prt_file)
        return
    text_file = os.path.join(folder_name, 'tex', subject_name + '_dif_2k.jpg')
    if not os.path.exists(text_file):
        print('ERROR: dif file does not exist!!', text_file)
        return

    texture_image = cv2.imread(text_file)
    texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)

    vertices, faces, normals, faces_normals, textures, face_textures = load_obj_mesh(mesh_file, with_normal=True, with_texture=True)
    vmin = vertices.min(0)
    vmax = vertices.max(0)
    up_axis = 1 if (vmax-vmin).argmax() == 1 else 2

    vmed = np.median(vertices, 0)
    vmed[up_axis] = 0.5*(vmax[up_axis]+vmin[up_axis])
    y_scale = 180/(vmax[up_axis] - vmin[up_axis])

    rndr.set_norm_mat(y_scale, vmed)
    rndr_uv.set_norm_mat(y_scale, vmed)

    tan, bitan = compute_tangent(vertices, faces, normals, textures, face_textures)
    prt = np.loadtxt(prt_file)
    face_prt = np.load(face_prt_file)
    rndr.set_mesh(vertices, faces, normals, faces_normals, textures, face_textures, prt, face_prt, tan, bitan)
    rndr.set_albedo(texture_image)

    rndr_uv.set_mesh(vertices, faces, normals, faces_normals, textures, face_textures, prt, face_prt, tan, bitan)
    rndr_uv.set_albedo(texture_image)

    os.makedirs(os.path.join(out_path, 'GEO', 'OBJ', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'PARAM', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'RENDER', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'MASK', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'UV_RENDER', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'UV_MASK', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'UV_POS', subject_name), exist_ok=True)
    os.makedirs(os.path.join(out_path, 'UV_NORMAL', subject_name), exist_ok=True)

    if not os.path.exists(os.path.join(out_path, 'val.txt')):
        f = open(os.path.join(out_path, 'val.txt'), 'w')
        f.close()

    # copy obj file
    cmd = 'cp %s %s' % (mesh_file, os.path.join(out_path, 'GEO', 'OBJ', subject_name))
    print(cmd)
    os.system(cmd)

    for p in pitch:
        for y in tqdm(range(0, 360, angl_step)):
            R = np.matmul(make_rotate(math.radians(p), 0, 0), make_rotate(0, math.radians(y), 0))
            if up_axis == 2:
                R = np.matmul(R, make_rotate(math.radians(90), 0, 0))

            rndr.rot_matrix = R
            rndr_uv.rot_matrix = R
            rndr.set_camera(cam)
            rndr_uv.set_camera(cam)

            for j in range(n_light):
                sh_id = random.randint(0, shs.shape[0]-1)
                sh = shs[sh_id]
                sh_angle = 0.2*np.pi*(random.random()-0.5)
                sh = rotateSH(sh, make_rotate(0, sh_angle, 0).T)

                dic = {'sh': sh, 'ortho_ratio': cam.ortho_ratio, 'scale': y_scale, 'center': vmed, 'R': R}

                rndr.set_sh(sh)
                rndr.analytic = False
                rndr.use_inverse_depth = False
                rndr.display()

                out_all_f = rndr.get_color(0)
                out_mask = out_all_f[:,:,3]
                out_all_f = cv2.cvtColor(out_all_f, cv2.COLOR_RGBA2BGR)

                np.save(os.path.join(out_path, 'PARAM', subject_name, '%d_%d_%02d.npy'%(y,p,j)), dic)
                cv2.imwrite(os.path.join(out_path, 'RENDER', subject_name, '%d_%d_%02d.jpg'%(y,p,j)), 255.0*out_all_f)
                cv2.imwrite(os.path.join(out_path, 'MASK', subject_name, '%d_%d_%02d.png'%(y,p,j)), 255.0*out_mask)

                rndr_uv.set_sh(sh)
                rndr_uv.analytic = False
                rndr_uv.use_inverse_depth = False
                rndr_uv.display()

                uv_color = rndr_uv.get_color(0)
                uv_color = cv2.cvtColor(uv_color, cv2.COLOR_RGBA2BGR)
                cv2.imwrite(os.path.join(out_path, 'UV_RENDER', subject_name, '%d_%d_%02d.jpg'%(y,p,j)), 255.0*uv_color)

                if y == 0 and j == 0 and p == pitch[0]:
                    uv_pos = rndr_uv.get_color(1)
                    uv_mask = uv_pos[:,:,3]
                    cv2.imwrite(os.path.join(out_path, 'UV_MASK', subject_name, '00.png'), 255.0*uv_mask)

                    data = {'default': uv_pos[:,:,:3]}  # 'default' is a reserved name
                    pyexr.write(os.path.join(out_path, 'UV_POS', subject_name, '00.exr'), data)

                    uv_nml = rndr_uv.get_color(2)
                    uv_nml = cv2.cvtColor(uv_nml, cv2.COLOR_RGBA2BGR)
                    cv2.imwrite(os.path.join(out_path, 'UV_NORMAL', subject_name, '00.png'), 255.0*uv_nml)


if __name__ == '__main__':
    shs = np.load('./env_sh.npy')

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='/home/shunsuke/Downloads/rp_dennis_posed_004_OBJ')
    parser.add_argument('-o', '--out_dir', type=str, default='/home/shunsuke/Documents/hf_human')
    parser.add_argument('-m', '--ms_rate', type=int, default=1, help='higher ms rate gives less aliased output; the MESA renderer only supports ms_rate=1')
    parser.add_argument('-e', '--egl', action='store_true', help='EGL rendering; use this when rendering on a headless server with an NVIDIA GPU')
    parser.add_argument('-s', '--size', type=int, default=512, help='rendering image size')
    args = parser.parse_args()

    # NOTE: the GL context has to be created before any other OpenGL function loads.
    from lib.renderer.gl.init_gl import initialize_GL_context
    initialize_GL_context(width=args.size, height=args.size, egl=args.egl)

    from lib.renderer.gl.prt_render import PRTRender
    rndr = PRTRender(width=args.size, height=args.size, ms_rate=args.ms_rate, egl=args.egl)
    rndr_uv = PRTRender(width=args.size, height=args.size, uv_mode=True, egl=args.egl)

    if args.input[-1] == '/':
        args.input = args.input[:-1]
    subject_name = args.input.split('/')[-1][:-4]
    render_prt_ortho(args.out_dir, args.input, subject_name, shs, rndr, rndr_uv, args.size, 1, 1, pitch=[0])
        PIFu/apps/train_color.py
    ADDED
    
    | @@ -0,0 +1,191 @@ | |
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
         | 
| 5 | 
            +
            ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            import json
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import cv2
         | 
| 11 | 
            +
            import random
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import torch.nn as nn
         | 
| 14 | 
            +
            from torch.utils.data import DataLoader
         | 
| 15 | 
            +
            from tqdm import tqdm
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from lib.options import BaseOptions
         | 
| 18 | 
            +
            from lib.mesh_util import *
         | 
| 19 | 
            +
            from lib.sample_util import *
         | 
| 20 | 
            +
            from lib.train_util import *
         | 
| 21 | 
            +
            from lib.data import *
         | 
| 22 | 
            +
            from lib.model import *
         | 
| 23 | 
            +
            from lib.geometry import index
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # get options
         | 
| 26 | 
            +
            opt = BaseOptions().parse()
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def train_color(opt):
         | 
| 29 | 
            +
                # set cuda
         | 
| 30 | 
            +
                cuda = torch.device('cuda:%d' % opt.gpu_id)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                train_dataset = TrainDataset(opt, phase='train')
         | 
| 33 | 
            +
                test_dataset = TrainDataset(opt, phase='test')
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                projection_mode = train_dataset.projection_mode
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                # create data loader
         | 
| 38 | 
            +
                train_data_loader = DataLoader(train_dataset,
         | 
| 39 | 
            +
                                               batch_size=opt.batch_size, shuffle=not opt.serial_batches,
         | 
| 40 | 
            +
                                               num_workers=opt.num_threads, pin_memory=opt.pin_memory)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                print('train data size: ', len(train_data_loader))
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                # NOTE: batch size should be 1 and use all the points for evaluation
         | 
| 45 | 
            +
                test_data_loader = DataLoader(test_dataset,
         | 
| 46 | 
            +
                                              batch_size=1, shuffle=False,
         | 
| 47 | 
            +
                                              num_workers=opt.num_threads, pin_memory=opt.pin_memory)
         | 
| 48 | 
            +
                print('test data size: ', len(test_data_loader))
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # create net
         | 
| 51 | 
            +
                netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                lr = opt.learning_rate
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                # Always use resnet for color regression
         | 
| 56 | 
            +
                netC = ResBlkPIFuNet(opt).to(device=cuda)
         | 
| 57 | 
            +
                optimizerC = torch.optim.Adam(netC.parameters(), lr=opt.learning_rate)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def set_train():
         | 
| 60 | 
            +
                    netG.eval()
         | 
| 61 | 
            +
                    netC.train()
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def set_eval():
         | 
| 64 | 
            +
                    netG.eval()
         | 
| 65 | 
            +
                    netC.eval()
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                print('Using NetworkG: ', netG.name, 'networkC: ', netC.name)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                # load checkpoints
         | 
| 70 | 
            +
                if opt.load_netG_checkpoint_path is not None:
         | 
| 71 | 
            +
                    print('loading for net G ...', opt.load_netG_checkpoint_path)
         | 
| 72 | 
            +
                    netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
         | 
| 73 | 
            +
                else:
         | 
| 74 | 
            +
                    model_path_G = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
         | 
| 75 | 
            +
                    print('loading for net G ...', model_path_G)
         | 
| 76 | 
            +
                    netG.load_state_dict(torch.load(model_path_G, map_location=cuda))
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                if opt.load_netC_checkpoint_path is not None:
         | 
| 79 | 
            +
                    print('loading for net C ...', opt.load_netC_checkpoint_path)
         | 
| 80 | 
            +
                    netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                if opt.continue_train:
         | 
| 83 | 
            +
                    if opt.resume_epoch < 0:
         | 
| 84 | 
            +
                        model_path_C = '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name)
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        model_path_C = '%s/%s/netC_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    print('Resuming from ', model_path_C)
         | 
| 89 | 
            +
                    netC.load_state_dict(torch.load(model_path_C, map_location=cuda))
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                os.makedirs(opt.checkpoints_path, exist_ok=True)
         | 
| 92 | 
            +
                os.makedirs(opt.results_path, exist_ok=True)
         | 
| 93 | 
            +
                os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
         | 
| 94 | 
            +
                os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
         | 
| 97 | 
            +
                with open(opt_log, 'w') as outfile:
         | 
| 98 | 
            +
                    outfile.write(json.dumps(vars(opt), indent=2))
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                # training
         | 
| 101 | 
            +
                start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch,0)
         | 
| 102 | 
            +
                for epoch in range(start_epoch, opt.num_epoch):
         | 
| 103 | 
            +
                    epoch_start_time = time.time()
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    set_train()
         | 
| 106 | 
            +
                    iter_data_time = time.time()
         | 
| 107 | 
            +
                    for train_idx, train_data in enumerate(train_data_loader):
         | 
| 108 | 
            +
                        iter_start_time = time.time()
         | 
| 109 | 
            +
                        # retrieve the data
         | 
| 110 | 
            +
                        image_tensor = train_data['img'].to(device=cuda)
         | 
| 111 | 
            +
                        calib_tensor = train_data['calib'].to(device=cuda)
         | 
| 112 | 
            +
                        color_sample_tensor = train_data['color_samples'].to(device=cuda)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                        image_tensor, calib_tensor = reshape_multiview_tensors(image_tensor, calib_tensor)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                        if opt.num_views > 1:
         | 
| 117 | 
            +
                            color_sample_tensor = reshape_sample_tensor(color_sample_tensor, opt.num_views)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        rgb_tensor = train_data['rgbs'].to(device=cuda)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                        with torch.no_grad():
         | 
| 122 | 
            +
                            netG.filter(image_tensor)
         | 
| 123 | 
            +
                        resC, error = netC.forward(image_tensor, netG.get_im_feat(), color_sample_tensor, calib_tensor, labels=rgb_tensor)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                        optimizerC.zero_grad()
         | 
| 126 | 
            +
                        error.backward()
         | 
| 127 | 
            +
                        optimizerC.step()
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        iter_net_time = time.time()
         | 
| 130 | 
            +
                        eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
         | 
| 131 | 
            +
                                iter_net_time - epoch_start_time)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                        if train_idx % opt.freq_plot == 0:
         | 
| 134 | 
            +
                            print(
         | 
| 135 | 
            +
                                'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | dataT: {6:.05f} | netT: {7:.05f} | ETA: {8:02d}:{9:02d}'.format(
         | 
| 136 | 
            +
                                    opt.name, epoch, train_idx, len(train_data_loader),
         | 
| 137 | 
            +
                                    error.item(),
         | 
| 138 | 
            +
                                    lr,
         | 
| 139 | 
            +
                                    iter_start_time - iter_data_time,
         | 
| 140 | 
            +
                                    iter_net_time - iter_start_time, int(eta // 60),
         | 
| 141 | 
            +
                                    int(eta - 60 * (eta // 60))))
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                        if train_idx % opt.freq_save == 0 and train_idx != 0:
         | 
| 144 | 
            +
                            torch.save(netC.state_dict(), '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name))
         | 
| 145 | 
            +
                            torch.save(netC.state_dict(), '%s/%s/netC_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                        if train_idx % opt.freq_save_ply == 0:
         | 
| 148 | 
            +
                save_path = '%s/%s/pred_col.ply' % (opt.results_path, opt.name)
                rgb = resC[0].transpose(0, 1).cpu() * 0.5 + 0.5
                points = color_sample_tensor[0].transpose(0, 1).cpu()
                save_samples_rgb(save_path, points.detach().numpy(), rgb.detach().numpy())

            iter_data_time = time.time()

        #### test
        with torch.no_grad():
            set_eval()

            if not opt.no_num_eval:
                test_losses = {}
                print('calc error (test) ...')
                test_color_error = calc_error_color(opt, netG, netC, cuda, test_dataset, 100)
                print('eval test | color error:', test_color_error)
                test_losses['test_color'] = test_color_error

                print('calc error (train) ...')
                train_dataset.is_train = False
                train_color_error = calc_error_color(opt, netG, netC, cuda, train_dataset, 100)
                train_dataset.is_train = True
                print('eval train | color error:', train_color_error)
                test_losses['train_color'] = train_color_error

            if not opt.no_gen_mesh:
                print('generate mesh (test) ...')
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    test_data = random.choice(test_dataset)
                    save_path = '%s/%s/test_eval_epoch%d_%s.obj' % (
                        opt.results_path, opt.name, epoch, test_data['name'])
                    gen_mesh_color(opt, netG, netC, cuda, test_data, save_path)

                print('generate mesh (train) ...')
                train_dataset.is_train = False
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    train_data = random.choice(train_dataset)
                    save_path = '%s/%s/train_eval_epoch%d_%s.obj' % (
                        opt.results_path, opt.name, epoch, train_data['name'])
                    gen_mesh_color(opt, netG, netC, cuda, train_data, save_path)
                train_dataset.is_train = True


if __name__ == '__main__':
    train_color(opt)
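The `* 0.5 + 0.5` rescale above assumes the color network's per-point predictions live in [-1, 1] (matching the dataset's Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform) and maps them back to [0, 1] RGB before the colored point cloud is written. A standalone sketch of the mapping, with a toy tensor:

import torch

resC0 = torch.tensor([[-1.0, 0.0, 1.0]])  # stand-in for resC[0]: colors in [-1, 1]
rgb = resC0 * 0.5 + 0.5                   # -> [[0.0, 0.5, 1.0]], valid RGB in [0, 1]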
    	
        PIFu/apps/train_shape.py
    ADDED
    
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

import time
import json
import numpy as np
import cv2
import random
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.options import BaseOptions
from lib.mesh_util import *
from lib.sample_util import *
from lib.train_util import *
from lib.data import *
from lib.model import *
from lib.geometry import index

# get options
opt = BaseOptions().parse()

def train(opt):
    # set cuda
    cuda = torch.device('cuda:%d' % opt.gpu_id)

    train_dataset = TrainDataset(opt, phase='train')
    test_dataset = TrainDataset(opt, phase='test')

    projection_mode = train_dataset.projection_mode

    # create data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size, shuffle=not opt.serial_batches,
                                   num_workers=opt.num_threads, pin_memory=opt.pin_memory)

    print('train data size: ', len(train_data_loader))

    # NOTE: batch size should be 1 and use all the points for evaluation
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=1, shuffle=False,
                                  num_workers=opt.num_threads, pin_memory=opt.pin_memory)
    print('test data size: ', len(test_data_loader))

    # create net
    netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
    optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.learning_rate, momentum=0, weight_decay=0)
    lr = opt.learning_rate
    print('Using Network: ', netG.name)

    def set_train():
        netG.train()

    def set_eval():
        netG.eval()

    # load checkpoints
    if opt.load_netG_checkpoint_path is not None:
        print('loading for net G ...', opt.load_netG_checkpoint_path)
        netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))

    if opt.continue_train:
        if opt.resume_epoch < 0:
            model_path = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
        else:
            model_path = '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)
        print('Resuming from ', model_path)
        netG.load_state_dict(torch.load(model_path, map_location=cuda))

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

    opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
    with open(opt_log, 'w') as outfile:
        outfile.write(json.dumps(vars(opt), indent=2))

    # training
    start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch, 0)
    for epoch in range(start_epoch, opt.num_epoch):
        epoch_start_time = time.time()

        set_train()
        iter_data_time = time.time()
        for train_idx, train_data in enumerate(train_data_loader):
            iter_start_time = time.time()

            # retrieve the data
            image_tensor = train_data['img'].to(device=cuda)
            calib_tensor = train_data['calib'].to(device=cuda)
            sample_tensor = train_data['samples'].to(device=cuda)

            image_tensor, calib_tensor = reshape_multiview_tensors(image_tensor, calib_tensor)

            if opt.num_views > 1:
                sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)

            label_tensor = train_data['labels'].to(device=cuda)

            res, error = netG.forward(image_tensor, sample_tensor, calib_tensor, labels=label_tensor)

            optimizerG.zero_grad()
            error.backward()
            optimizerG.step()

            iter_net_time = time.time()
            eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
                    iter_net_time - epoch_start_time)

            if train_idx % opt.freq_plot == 0:
                print(
                    'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
                        opt.name, epoch, train_idx, len(train_data_loader), error.item(), lr, opt.sigma,
                        iter_start_time - iter_data_time,
                        iter_net_time - iter_start_time, int(eta // 60),
                        int(eta - 60 * (eta // 60))))

            if train_idx % opt.freq_save == 0 and train_idx != 0:
                torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
                torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))

            if train_idx % opt.freq_save_ply == 0:
                save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
                r = res[0].cpu()
                points = sample_tensor[0].transpose(0, 1).cpu()
                save_samples_truncted_prob(save_path, points.detach().numpy(), r.detach().numpy())

            iter_data_time = time.time()

        # update learning rate
        lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)

        #### test
        with torch.no_grad():
            set_eval()

            if not opt.no_num_eval:
                test_losses = {}
                print('calc error (test) ...')
                test_errors = calc_error(opt, netG, cuda, test_dataset, 100)
                print('eval test MSE: {0:06f} IOU: {1:06f} prec: {2:06f} recall: {3:06f}'.format(*test_errors))
                MSE, IOU, prec, recall = test_errors
                test_losses['MSE(test)'] = MSE
                test_losses['IOU(test)'] = IOU
                test_losses['prec(test)'] = prec
                test_losses['recall(test)'] = recall

                print('calc error (train) ...')
                train_dataset.is_train = False
                train_errors = calc_error(opt, netG, cuda, train_dataset, 100)
                train_dataset.is_train = True
                print('eval train MSE: {0:06f} IOU: {1:06f} prec: {2:06f} recall: {3:06f}'.format(*train_errors))
                MSE, IOU, prec, recall = train_errors
                test_losses['MSE(train)'] = MSE
                test_losses['IOU(train)'] = IOU
                test_losses['prec(train)'] = prec
                test_losses['recall(train)'] = recall

            if not opt.no_gen_mesh:
                print('generate mesh (test) ...')
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    test_data = random.choice(test_dataset)
                    save_path = '%s/%s/test_eval_epoch%d_%s.obj' % (
                        opt.results_path, opt.name, epoch, test_data['name'])
                    gen_mesh(opt, netG, cuda, test_data, save_path)

                print('generate mesh (train) ...')
                train_dataset.is_train = False
                for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
                    train_data = random.choice(train_dataset)
                    save_path = '%s/%s/train_eval_epoch%d_%s.obj' % (
                        opt.results_path, opt.name, epoch, train_data['name'])
                    gen_mesh(opt, netG, cuda, train_data, save_path)
                train_dataset.is_train = True


if __name__ == '__main__':
    train(opt)
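The ETA printed in the log line above is the average per-iteration time so far, extrapolated over the whole epoch, minus the time already elapsed. The same arithmetic in isolation, with made-up timings for illustration:

elapsed = 120.0   # seconds since epoch_start_time
train_idx = 59    # current iteration index (0-based)
iters = 600       # len(train_data_loader)

eta = (elapsed / (train_idx + 1)) * iters - elapsed
print('%02d:%02d' % (int(eta // 60), int(eta - 60 * (eta // 60))))  # 18:00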
    	
        PIFu/env_sh.npy
    ADDED
    
Binary file (52 kB).
    	
        PIFu/environment.yml
    ADDED
    
name: PIFu
channels:
- pytorch
- defaults
dependencies:
- opencv
- pytorch
- json
- pyexr
- cv2
- PIL
- skimage
- tqdm
- pyembree
- shapely
- rtree
- xxhash
- trimesh
- PyOpenGL
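Note that several of these entries will not resolve as conda package names as written: json ships with the Python standard library, cv2 is the import name provided by the opencv package already listed, and PIL and skimage are import names whose installable distributions are pillow and scikit-image. If conda env create -f environment.yml fails to solve, substituting the installable names is the likely fix.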
    	
        PIFu/inputs/.gitignore
    ADDED
    
*
!.gitignore
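These two patterns keep the inputs/ directory under version control while ignoring everything users drop into it: * excludes all files, and !.gitignore re-includes the .gitignore itself.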
    	
        PIFu/lib/__init__.py
    ADDED
    
File without changes
    	
        PIFu/lib/colab_util.py
    ADDED
    
import io
import os
import torch
from skimage.io import imread
import numpy as np
import cv2
from tqdm import tqdm_notebook as tqdm
import base64
from base64 import b64encode
from IPython.display import HTML

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLOrthographicCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    HardPhongShader,
    TexturesVertex
)

def set_renderer():
    # Setup
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Initialize an OpenGL orthographic camera.
    R, T = look_at_view_transform(2.0, 0, 180)
    cameras = OpenGLOrthographicCameras(device=device, R=R, T=T)

    raster_settings = RasterizationSettings(
        image_size=512,
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=None,
        max_faces_per_bin=None
    )

    lights = PointLights(device=device, location=((2.0, 2.0, 2.0),))

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=HardPhongShader(
            device=device,
            cameras=cameras,
            lights=lights
        )
    )
    return renderer

def get_verts_rgb_colors(obj_path):
    # Vertex-colored OBJ lines look like 'v x y z r g b' (7 fields)
    rgb_colors = []
    with open(obj_path) as f:
        for line in f.readlines():
            ls = line.split(' ')
            if len(ls) == 7:
                rgb_colors.append(ls[-3:])
    return np.array(rgb_colors, dtype='float32')[None, :, :]

def generate_video_from_obj(obj_path, video_path, renderer):
    # Setup
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Load per-vertex colors from the obj file
    verts_rgb_colors = get_verts_rgb_colors(obj_path)
    verts_rgb_colors = torch.from_numpy(verts_rgb_colors).to(device)
    textures = TexturesVertex(verts_features=verts_rgb_colors)
    wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors) * 0.75)

    # Load obj
    mesh = load_objs_as_meshes([obj_path], device=device)

    # Set mesh
    vers = mesh._verts_list
    faces = mesh._faces_list
    mesh_w_tex = Meshes(vers, faces, textures)
    mesh_wo_tex = Meshes(vers, faces, wo_textures)

    # create VideoWriter
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = cv2.VideoWriter(video_path, fourcc, 20.0, (1024, 512))

    for i in tqdm(range(90)):
        R, T = look_at_view_transform(1.8, 0, i * 4, device=device)
        images_w_tex = renderer(mesh_w_tex, R=R, T=T)
        images_w_tex = np.clip(images_w_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255
        images_wo_tex = renderer(mesh_wo_tex, R=R, T=T)
        images_wo_tex = np.clip(images_wo_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255
        image = np.concatenate([images_w_tex, images_wo_tex], axis=1)
        out.write(image.astype('uint8'))
    out.release()

def video(path):
    mp4 = open(path, 'rb').read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML('<video width=500 controls loop> <source src="%s" type="video/mp4"></video>' % data_url)
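A minimal usage sketch of these helpers in a notebook (the paths are illustrative; a CUDA device is assumed):

renderer = set_renderer()
generate_video_from_obj('results/result.obj', 'results/result.mp4', renderer)
video('results/result.mp4')  # returns an HTML element embedding the rendered clip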
    	
        PIFu/lib/data/BaseDataset.py
    ADDED
    
from torch.utils.data import Dataset
import random


class BaseDataset(Dataset):
    '''
    This is the base dataset class.
    By itself it does nothing and is not runnable.
    Check the get_item function to see what a subclass should return.
    '''

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def __init__(self, opt, phase='train'):
        self.opt = opt
        self.phase = phase
        self.is_train = self.phase == 'train'
        self.projection_mode = 'orthogonal'  # Declare projection mode here

    def __len__(self):
        return 0

    def get_item(self, index):
        # In case of a missing file or IO error, switch to a random sample instead
        try:
            res = {
                'name': None,  # name of this subject
                'b_min': None,  # Bounding box (x_min, y_min, z_min) of target space
                'b_max': None,  # Bounding box (x_max, y_max, z_max) of target space

                'samples': None,  # [3, N] samples
                'labels': None,  # [1, N] labels

                'img': None,  # [num_views, C, H, W] input images
                'calib': None,  # [num_views, 4, 4] calibration matrix
                'extrinsic': None,  # [num_views, 4, 4] extrinsic matrix
                'mask': None,  # [num_views, 1, H, W] segmentation masks
            }
            return res
        except Exception:
            print("Requested index %s has missing files. Using a random sample instead." % index)
            return self.get_item(index=random.randint(0, self.__len__() - 1))

    def __getitem__(self, index):
        return self.get_item(index)
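A minimal sketch of the intended subclassing pattern (a toy dataset with dummy tensors shaped per the keys documented in get_item; not part of the repo):

import torch

class ToyDataset(BaseDataset):
    def __len__(self):
        return 4

    def get_item(self, index):
        return {
            'name': 'toy_%d' % index,
            'b_min': torch.tensor([-1.0, -1.0, -1.0]),
            'b_max': torch.tensor([1.0, 1.0, 1.0]),
            'samples': torch.rand(3, 8),           # [3, N]
            'labels': torch.rand(1, 8),            # [1, N]
            'img': torch.rand(1, 3, 512, 512),     # [num_views, C, H, W]
            'calib': torch.eye(4).unsqueeze(0),    # [num_views, 4, 4]
            'extrinsic': torch.eye(4).unsqueeze(0),
            'mask': torch.ones(1, 1, 512, 512),    # [num_views, 1, H, W]
        }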
    	
        PIFu/lib/data/EvalDataset.py
    ADDED
    
from torch.utils.data import Dataset
import numpy as np
import os
import random
import torchvision.transforms as transforms
from PIL import Image, ImageOps
import cv2
import torch
from PIL.ImageFilter import GaussianBlur
import trimesh


class EvalDataset(Dataset):
    @staticmethod
    def modify_commandline_options(parser):
        return parser

    def __init__(self, opt, root=None):
        self.opt = opt
        self.projection_mode = 'orthogonal'

        # Path setup
        self.root = self.opt.dataroot
        if root is not None:
            self.root = root
        self.RENDER = os.path.join(self.root, 'RENDER')
        self.MASK = os.path.join(self.root, 'MASK')
        self.PARAM = os.path.join(self.root, 'PARAM')
        self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')

        self.phase = 'val'
        self.load_size = self.opt.loadSize

        self.num_views = self.opt.num_views

        self.max_view_angle = 360
        self.interval = 1
        self.subjects = self.get_subjects()

        # PIL to tensor
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def get_subjects(self):
        var_file = os.path.join(self.root, 'val.txt')
        if os.path.exists(var_file):
            var_subjects = np.loadtxt(var_file, dtype=str)
            return sorted(list(var_subjects))
        all_subjects = os.listdir(self.RENDER)
        return sorted(list(all_subjects))

    def __len__(self):
        return len(self.subjects) * self.max_view_angle // self.interval

    def get_render(self, subject, num_views, view_id=None, random_sample=False):
        '''
        Return the render data
        :param subject: subject name
        :param num_views: how many views to return
        :param view_id: the first view_id. If None, select a random one.
        :return:
            'img': [num_views, C, W, H] images
            'calib': [num_views, 4, 4] calibration matrix
            'extrinsic': [num_views, 4, 4] extrinsic matrix
            'mask': [num_views, 1, W, H] masks
        '''
        # For now we only have pitch = 00. Hard code it here
        pitch = 0
        # Select a random view_id from self.max_view_angle if not given
        if view_id is None:
            view_id = np.random.randint(self.max_view_angle)
        # The ids are an even distribution of num_views around view_id
        view_ids = [(view_id + self.max_view_angle // num_views * offset) % self.max_view_angle
                    for offset in range(num_views)]
        if random_sample:
            view_ids = np.random.choice(self.max_view_angle, num_views, replace=False)

        calib_list = []
        render_list = []
        mask_list = []
        extrinsic_list = []

        for vid in view_ids:
            param_path = os.path.join(self.PARAM, subject, '%d_%02d.npy' % (vid, pitch))
            render_path = os.path.join(self.RENDER, subject, '%d_%02d.jpg' % (vid, pitch))
            mask_path = os.path.join(self.MASK, subject, '%d_%02d.png' % (vid, pitch))

            # loading calibration data
            param = np.load(param_path)
            # pixel unit / world unit
            ortho_ratio = param.item().get('ortho_ratio')
            # world unit / model unit
            scale = param.item().get('scale')
            # camera center world coordinate
            center = param.item().get('center')
            # model rotation
            R = param.item().get('R')

            translate = -np.matmul(R, center).reshape(3, 1)
            extrinsic = np.concatenate([R, translate], axis=1)
            extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
            # Match camera space to image pixel space
            scale_intrinsic = np.identity(4)
            scale_intrinsic[0, 0] = scale / ortho_ratio
            scale_intrinsic[1, 1] = -scale / ortho_ratio
            scale_intrinsic[2, 2] = -scale / ortho_ratio
            # Match image pixel space to image uv space
            uv_intrinsic = np.identity(4)
            uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
            uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
            uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
            # Transform under image pixel space
            trans_intrinsic = np.identity(4)

            mask = Image.open(mask_path).convert('L')
            render = Image.open(render_path).convert('RGB')

            intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
            calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
            extrinsic = torch.Tensor(extrinsic).float()

            mask = transforms.Resize(self.load_size)(mask)
            mask = transforms.ToTensor()(mask).float()
            mask_list.append(mask)

            render = self.to_tensor(render)
            render = mask.expand_as(render) * render

            render_list.append(render)
            calib_list.append(calib)
            extrinsic_list.append(extrinsic)

        return {
            'img': torch.stack(render_list, dim=0),
            'calib': torch.stack(calib_list, dim=0),
            'extrinsic': torch.stack(extrinsic_list, dim=0),
            'mask': torch.stack(mask_list, dim=0)
        }

    def get_item(self, index):
        # In case of a missing file or IO error, switch to a random sample instead
        try:
            sid = index % len(self.subjects)
            vid = (index // len(self.subjects)) * self.interval
            # name of the subject 'rp_xxxx_xxx'
            subject = self.subjects[sid]
            res = {
                'name': subject,
                'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
                'sid': sid,
                'vid': vid,
            }
            render_data = self.get_render(subject, num_views=self.num_views, view_id=vid,
                                          random_sample=self.opt.random_multiview)
            res.update(render_data)
            return res
        except Exception as e:
            print(e)
            return self.get_item(index=random.randint(0, self.__len__() - 1))

    def __getitem__(self, index):
        return self.get_item(index)
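get_item above unrolls a flat index into a (subject, view) pair, so every subject is visited once per view angle. The decomposition in isolation, with toy sizes for illustration:

num_subjects, interval = 3, 1   # len(self.subjects), self.interval
for index in range(6):
    sid = index % num_subjects                # subject id
    vid = (index // num_subjects) * interval  # view/yaw id
    print(index, (sid, vid))
# 0 (0, 0)  1 (1, 0)  2 (2, 0)  3 (0, 1)  4 (1, 1)  5 (2, 1)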
    	
        PIFu/lib/data/TrainDataset.py
    ADDED
    
from torch.utils.data import Dataset
import numpy as np
import os
import random
import torchvision.transforms as transforms
from PIL import Image, ImageOps
import cv2
import torch
from PIL.ImageFilter import GaussianBlur
import trimesh
import logging

log = logging.getLogger('trimesh')
log.setLevel(40)

def load_trimesh(root_dir):
    folders = os.listdir(root_dir)
    meshs = {}
    for i, f in enumerate(folders):
        sub_name = f
        meshs[sub_name] = trimesh.load(os.path.join(root_dir, f, '%s_100k.obj' % sub_name))

    return meshs

def save_samples_truncted_prob(fname, points, prob):
    '''
    Save the visualization of sampling to a ply file.
    Red points represent positive predictions.
    Green points represent negative predictions.
    :param fname: File name to save
    :param points: [N, 3] array of points
    :param prob: [N, 1] array of predictions in the range [0~1]
    :return:
    '''
    r = (prob > 0.5).reshape([-1, 1]) * 255
    g = (prob < 0.5).reshape([-1, 1]) * 255
    b = np.zeros(r.shape)

    to_save = np.concatenate([points, r, g, b], axis=-1)
    return np.savetxt(fname,
                      to_save,
                      fmt='%.6f %.6f %.6f %d %d %d',
                      comments='',
                      header=(
                          'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format(
                          points.shape[0])
                      )
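A minimal sketch of calling this helper on synthetic data (the output path is illustrative):

import numpy as np

pts = np.random.rand(100, 3)   # [N, 3] sample positions
prob = np.random.rand(100, 1)  # [N, 1] occupancy predictions in [0, 1]
save_samples_truncted_prob('/tmp/samples.ply', pts, prob)
# points with prob > 0.5 are written red, the others green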

class TrainDataset(Dataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def __init__(self, opt, phase='train'):
        self.opt = opt
        self.projection_mode = 'orthogonal'

        # Path setup
        self.root = self.opt.dataroot
        self.RENDER = os.path.join(self.root, 'RENDER')
        self.MASK = os.path.join(self.root, 'MASK')
        self.PARAM = os.path.join(self.root, 'PARAM')
        self.UV_MASK = os.path.join(self.root, 'UV_MASK')
        self.UV_NORMAL = os.path.join(self.root, 'UV_NORMAL')
        self.UV_RENDER = os.path.join(self.root, 'UV_RENDER')
        self.UV_POS = os.path.join(self.root, 'UV_POS')
        self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')

        self.B_MIN = np.array([-128, -28, -128])
        self.B_MAX = np.array([128, 228, 128])

        self.is_train = (phase == 'train')
        self.load_size = self.opt.loadSize

        self.num_views = self.opt.num_views

        self.num_sample_inout = self.opt.num_sample_inout
        self.num_sample_color = self.opt.num_sample_color

        self.yaw_list = list(range(0, 360, 1))
        self.pitch_list = [0]
        self.subjects = self.get_subjects()

        # PIL to tensor
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # augmentation
        self.aug_trans = transforms.Compose([
            transforms.ColorJitter(brightness=opt.aug_bri, contrast=opt.aug_con, saturation=opt.aug_sat,
                                   hue=opt.aug_hue)
        ])

        self.mesh_dic = load_trimesh(self.OBJ)

    def get_subjects(self):
        all_subjects = os.listdir(self.RENDER)
        var_subjects = np.loadtxt(os.path.join(self.root, 'val.txt'), dtype=str)
        if len(var_subjects) == 0:
            return all_subjects

        if self.is_train:
            return sorted(list(set(all_subjects) - set(var_subjects)))
        else:
            return sorted(list(var_subjects))

    def __len__(self):
        return len(self.subjects) * len(self.yaw_list) * len(self.pitch_list)

    def get_render(self, subject, num_views, yid=0, pid=0, random_sample=False):
        '''
        Return the render data
        :param subject: subject name
        :param num_views: how many views to return
        :param yid: the first view id. If None, select a random one.
        :return:
            'img': [num_views, C, W, H] images
            'calib': [num_views, 4, 4] calibration matrix
            'extrinsic': [num_views, 4, 4] extrinsic matrix
            'mask': [num_views, 1, W, H] masks
         | 
| 125 | 
            +
                    '''
         | 
| 126 | 
            +
                    pitch = self.pitch_list[pid]
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    # The ids are an even distribution of num_views around view_id
         | 
| 129 | 
            +
                    view_ids = [self.yaw_list[(yid + len(self.yaw_list) // num_views * offset) % len(self.yaw_list)]
         | 
| 130 | 
            +
                                for offset in range(num_views)]
         | 
| 131 | 
            +
                    if random_sample:
         | 
| 132 | 
            +
                        view_ids = np.random.choice(self.yaw_list, num_views, replace=False)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    calib_list = []
         | 
| 135 | 
            +
                    render_list = []
         | 
| 136 | 
            +
                    mask_list = []
         | 
| 137 | 
            +
                    extrinsic_list = []
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    for vid in view_ids:
         | 
| 140 | 
            +
                        param_path = os.path.join(self.PARAM, subject, '%d_%d_%02d.npy' % (vid, pitch, 0))
         | 
| 141 | 
            +
                        render_path = os.path.join(self.RENDER, subject, '%d_%d_%02d.jpg' % (vid, pitch, 0))
         | 
| 142 | 
            +
                        mask_path = os.path.join(self.MASK, subject, '%d_%d_%02d.png' % (vid, pitch, 0))
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                        # loading calibration data
         | 
| 145 | 
            +
                        param = np.load(param_path, allow_pickle=True)
         | 
| 146 | 
            +
                        # pixel unit / world unit
         | 
| 147 | 
            +
                        ortho_ratio = param.item().get('ortho_ratio')
         | 
| 148 | 
            +
                        # world unit / model unit
         | 
| 149 | 
            +
                        scale = param.item().get('scale')
         | 
| 150 | 
            +
                        # camera center world coordinate
         | 
| 151 | 
            +
                        center = param.item().get('center')
         | 
| 152 | 
            +
                        # model rotation
         | 
| 153 | 
            +
                        R = param.item().get('R')
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        translate = -np.matmul(R, center).reshape(3, 1)
         | 
| 156 | 
            +
                        extrinsic = np.concatenate([R, translate], axis=1)
         | 
| 157 | 
            +
                        extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
         | 
| 158 | 
            +
                        # Match camera space to image pixel space
         | 
| 159 | 
            +
                        scale_intrinsic = np.identity(4)
         | 
| 160 | 
            +
                        scale_intrinsic[0, 0] = scale / ortho_ratio
         | 
| 161 | 
            +
                        scale_intrinsic[1, 1] = -scale / ortho_ratio
         | 
| 162 | 
            +
                        scale_intrinsic[2, 2] = scale / ortho_ratio
         | 
| 163 | 
            +
                        # Match image pixel space to image uv space
         | 
| 164 | 
            +
                        uv_intrinsic = np.identity(4)
         | 
| 165 | 
            +
                        uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
         | 
| 166 | 
            +
                        uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
         | 
| 167 | 
            +
                        uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
         | 
| 168 | 
            +
                        # Transform under image pixel space
         | 
| 169 | 
            +
                        trans_intrinsic = np.identity(4)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                        mask = Image.open(mask_path).convert('L')
         | 
| 172 | 
            +
                        render = Image.open(render_path).convert('RGB')
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                        if self.is_train:
         | 
| 175 | 
            +
                            # Pad images
         | 
| 176 | 
            +
                            pad_size = int(0.1 * self.load_size)
         | 
| 177 | 
            +
                            render = ImageOps.expand(render, pad_size, fill=0)
         | 
| 178 | 
            +
                            mask = ImageOps.expand(mask, pad_size, fill=0)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                            w, h = render.size
         | 
| 181 | 
            +
                            th, tw = self.load_size, self.load_size
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                            # random flip
         | 
| 184 | 
            +
                            if self.opt.random_flip and np.random.rand() > 0.5:
         | 
| 185 | 
            +
                                scale_intrinsic[0, 0] *= -1
         | 
| 186 | 
            +
                                render = transforms.RandomHorizontalFlip(p=1.0)(render)
         | 
| 187 | 
            +
                                mask = transforms.RandomHorizontalFlip(p=1.0)(mask)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                            # random scale
         | 
| 190 | 
            +
                            if self.opt.random_scale:
         | 
| 191 | 
            +
                                rand_scale = random.uniform(0.9, 1.1)
         | 
| 192 | 
            +
                                w = int(rand_scale * w)
         | 
| 193 | 
            +
                                h = int(rand_scale * h)
         | 
| 194 | 
            +
                                render = render.resize((w, h), Image.BILINEAR)
         | 
| 195 | 
            +
                                mask = mask.resize((w, h), Image.NEAREST)
         | 
| 196 | 
            +
                                scale_intrinsic *= rand_scale
         | 
| 197 | 
            +
                                scale_intrinsic[3, 3] = 1
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                            # random translate in the pixel space
         | 
| 200 | 
            +
                            if self.opt.random_trans:
         | 
| 201 | 
            +
                                dx = random.randint(-int(round((w - tw) / 10.)),
         | 
| 202 | 
            +
                                                    int(round((w - tw) / 10.)))
         | 
| 203 | 
            +
                                dy = random.randint(-int(round((h - th) / 10.)),
         | 
| 204 | 
            +
                                                    int(round((h - th) / 10.)))
         | 
| 205 | 
            +
                            else:
         | 
| 206 | 
            +
                                dx = 0
         | 
| 207 | 
            +
                                dy = 0
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                            trans_intrinsic[0, 3] = -dx / float(self.opt.loadSize // 2)
         | 
| 210 | 
            +
                            trans_intrinsic[1, 3] = -dy / float(self.opt.loadSize // 2)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                            x1 = int(round((w - tw) / 2.)) + dx
         | 
| 213 | 
            +
                            y1 = int(round((h - th) / 2.)) + dy
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                            render = render.crop((x1, y1, x1 + tw, y1 + th))
         | 
| 216 | 
            +
                            mask = mask.crop((x1, y1, x1 + tw, y1 + th))
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                            render = self.aug_trans(render)
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                            # random blur
         | 
| 221 | 
            +
                            if self.opt.aug_blur > 0.00001:
         | 
| 222 | 
            +
                                blur = GaussianBlur(np.random.uniform(0, self.opt.aug_blur))
         | 
| 223 | 
            +
                                render = render.filter(blur)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                        intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
         | 
| 226 | 
            +
                        calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
         | 
| 227 | 
            +
                        extrinsic = torch.Tensor(extrinsic).float()
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                        mask = transforms.Resize(self.load_size)(mask)
         | 
| 230 | 
            +
                        mask = transforms.ToTensor()(mask).float()
         | 
| 231 | 
            +
                        mask_list.append(mask)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                        render = self.to_tensor(render)
         | 
| 234 | 
            +
                        render = mask.expand_as(render) * render
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                        render_list.append(render)
         | 
| 237 | 
            +
                        calib_list.append(calib)
         | 
| 238 | 
            +
                        extrinsic_list.append(extrinsic)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    return {
         | 
| 241 | 
            +
                        'img': torch.stack(render_list, dim=0),
         | 
| 242 | 
            +
                        'calib': torch.stack(calib_list, dim=0),
         | 
| 243 | 
            +
                        'extrinsic': torch.stack(extrinsic_list, dim=0),
         | 
| 244 | 
            +
                        'mask': torch.stack(mask_list, dim=0)
         | 
| 245 | 
            +
                    }
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                def select_sampling_method(self, subject):
         | 
| 248 | 
            +
                    if not self.is_train:
         | 
| 249 | 
            +
                        random.seed(1991)
         | 
| 250 | 
            +
                        np.random.seed(1991)
         | 
| 251 | 
            +
                        torch.manual_seed(1991)
         | 
| 252 | 
            +
                    mesh = self.mesh_dic[subject]
         | 
| 253 | 
            +
                    surface_points, _ = trimesh.sample.sample_surface(mesh, 4 * self.num_sample_inout)
         | 
| 254 | 
            +
                    sample_points = surface_points + np.random.normal(scale=self.opt.sigma, size=surface_points.shape)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    # add random points within image space
         | 
| 257 | 
            +
                    length = self.B_MAX - self.B_MIN
         | 
| 258 | 
            +
                    random_points = np.random.rand(self.num_sample_inout // 4, 3) * length + self.B_MIN
         | 
| 259 | 
            +
                    sample_points = np.concatenate([sample_points, random_points], 0)
         | 
| 260 | 
            +
                    np.random.shuffle(sample_points)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    inside = mesh.contains(sample_points)
         | 
| 263 | 
            +
                    inside_points = sample_points[inside]
         | 
| 264 | 
            +
                    outside_points = sample_points[np.logical_not(inside)]
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    nin = inside_points.shape[0]
         | 
| 267 | 
            +
                    inside_points = inside_points[
         | 
| 268 | 
            +
                                    :self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else inside_points
         | 
| 269 | 
            +
                    outside_points = outside_points[
         | 
| 270 | 
            +
                                     :self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else outside_points[
         | 
| 271 | 
            +
                                                                                                           :(self.num_sample_inout - nin)]
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    samples = np.concatenate([inside_points, outside_points], 0).T
         | 
| 274 | 
            +
                    labels = np.concatenate([np.ones((1, inside_points.shape[0])), np.zeros((1, outside_points.shape[0]))], 1)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    # save_samples_truncted_prob('out.ply', samples.T, labels.T)
         | 
| 277 | 
            +
                    # exit()
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    samples = torch.Tensor(samples).float()
         | 
| 280 | 
            +
                    labels = torch.Tensor(labels).float()
         | 
| 281 | 
            +
                    
         | 
| 282 | 
            +
                    del mesh
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    return {
         | 
| 285 | 
            +
                        'samples': samples,
         | 
| 286 | 
            +
                        'labels': labels
         | 
| 287 | 
            +
                    }
         | 
| 288 | 
            +
             | 
| 289 | 
            +
             | 
| 290 | 
            +
                def get_color_sampling(self, subject, yid, pid=0):
         | 
| 291 | 
            +
                    yaw = self.yaw_list[yid]
         | 
| 292 | 
            +
                    pitch = self.pitch_list[pid]
         | 
| 293 | 
            +
                    uv_render_path = os.path.join(self.UV_RENDER, subject, '%d_%d_%02d.jpg' % (yaw, pitch, 0))
         | 
| 294 | 
            +
                    uv_mask_path = os.path.join(self.UV_MASK, subject, '%02d.png' % (0))
         | 
| 295 | 
            +
                    uv_pos_path = os.path.join(self.UV_POS, subject, '%02d.exr' % (0))
         | 
| 296 | 
            +
                    uv_normal_path = os.path.join(self.UV_NORMAL, subject, '%02d.png' % (0))
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    # Segmentation mask for the uv render.
         | 
| 299 | 
            +
                    # [H, W] bool
         | 
| 300 | 
            +
                    uv_mask = cv2.imread(uv_mask_path)
         | 
| 301 | 
            +
                    uv_mask = uv_mask[:, :, 0] != 0
         | 
| 302 | 
            +
                    # UV render. each pixel is the color of the point.
         | 
| 303 | 
            +
                    # [H, W, 3] 0 ~ 1 float
         | 
| 304 | 
            +
                    uv_render = cv2.imread(uv_render_path)
         | 
| 305 | 
            +
                    uv_render = cv2.cvtColor(uv_render, cv2.COLOR_BGR2RGB) / 255.0
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    # Normal render. each pixel is the surface normal of the point.
         | 
| 308 | 
            +
                    # [H, W, 3] -1 ~ 1 float
         | 
| 309 | 
            +
                    uv_normal = cv2.imread(uv_normal_path)
         | 
| 310 | 
            +
                    uv_normal = cv2.cvtColor(uv_normal, cv2.COLOR_BGR2RGB) / 255.0
         | 
| 311 | 
            +
                    uv_normal = 2.0 * uv_normal - 1.0
         | 
| 312 | 
            +
                    # Position render. each pixel is the xyz coordinates of the point
         | 
| 313 | 
            +
                    uv_pos = cv2.imread(uv_pos_path, 2 | 4)[:, :, ::-1]
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    ### In these few lines we flattern the masks, positions, and normals
         | 
| 316 | 
            +
                    uv_mask = uv_mask.reshape((-1))
         | 
| 317 | 
            +
                    uv_pos = uv_pos.reshape((-1, 3))
         | 
| 318 | 
            +
                    uv_render = uv_render.reshape((-1, 3))
         | 
| 319 | 
            +
                    uv_normal = uv_normal.reshape((-1, 3))
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    surface_points = uv_pos[uv_mask]
         | 
| 322 | 
            +
                    surface_colors = uv_render[uv_mask]
         | 
| 323 | 
            +
                    surface_normal = uv_normal[uv_mask]
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    if self.num_sample_color:
         | 
| 326 | 
            +
                        sample_list = random.sample(range(0, surface_points.shape[0] - 1), self.num_sample_color)
         | 
| 327 | 
            +
                        surface_points = surface_points[sample_list].T
         | 
| 328 | 
            +
                        surface_colors = surface_colors[sample_list].T
         | 
| 329 | 
            +
                        surface_normal = surface_normal[sample_list].T
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    # Samples are around the true surface with an offset
         | 
| 332 | 
            +
                    normal = torch.Tensor(surface_normal).float()
         | 
| 333 | 
            +
                    samples = torch.Tensor(surface_points).float() \
         | 
| 334 | 
            +
                              + torch.normal(mean=torch.zeros((1, normal.size(1))), std=self.opt.sigma).expand_as(normal) * normal
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    # Normalized to [-1, 1]
         | 
| 337 | 
            +
                    rgbs_color = 2.0 * torch.Tensor(surface_colors).float() - 1.0
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    return {
         | 
| 340 | 
            +
                        'color_samples': samples,
         | 
| 341 | 
            +
                        'rgbs': rgbs_color
         | 
| 342 | 
            +
                    }
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                def get_item(self, index):
         | 
| 345 | 
            +
                    # In case of a missing file or IO error, switch to a random sample instead
         | 
| 346 | 
            +
                    # try:
         | 
| 347 | 
            +
                    sid = index % len(self.subjects)
         | 
| 348 | 
            +
                    tmp = index // len(self.subjects)
         | 
| 349 | 
            +
                    yid = tmp % len(self.yaw_list)
         | 
| 350 | 
            +
                    pid = tmp // len(self.yaw_list)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    # name of the subject 'rp_xxxx_xxx'
         | 
| 353 | 
            +
                    subject = self.subjects[sid]
         | 
| 354 | 
            +
                    res = {
         | 
| 355 | 
            +
                        'name': subject,
         | 
| 356 | 
            +
                        'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
         | 
| 357 | 
            +
                        'sid': sid,
         | 
| 358 | 
            +
                        'yid': yid,
         | 
| 359 | 
            +
                        'pid': pid,
         | 
| 360 | 
            +
                        'b_min': self.B_MIN,
         | 
| 361 | 
            +
                        'b_max': self.B_MAX,
         | 
| 362 | 
            +
                    }
         | 
| 363 | 
            +
                    render_data = self.get_render(subject, num_views=self.num_views, yid=yid, pid=pid,
         | 
| 364 | 
            +
                                                    random_sample=self.opt.random_multiview)
         | 
| 365 | 
            +
                    res.update(render_data)
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    if self.opt.num_sample_inout:
         | 
| 368 | 
            +
                        sample_data = self.select_sampling_method(subject)
         | 
| 369 | 
            +
                        res.update(sample_data)
         | 
| 370 | 
            +
                    
         | 
| 371 | 
            +
                    # img = np.uint8((np.transpose(render_data['img'][0].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0)
         | 
| 372 | 
            +
                    # rot = render_data['calib'][0,:3, :3]
         | 
| 373 | 
            +
                    # trans = render_data['calib'][0,:3, 3:4]
         | 
| 374 | 
            +
                    # pts = torch.addmm(trans, rot, sample_data['samples'][:, sample_data['labels'][0] > 0.5])  # [3, N]
         | 
| 375 | 
            +
                    # pts = 0.5 * (pts.numpy().T + 1.0) * render_data['img'].size(2)
         | 
| 376 | 
            +
                    # for p in pts:
         | 
| 377 | 
            +
                    #     img = cv2.circle(img, (p[0], p[1]), 2, (0,255,0), -1)
         | 
| 378 | 
            +
                    # cv2.imshow('test', img)
         | 
| 379 | 
            +
                    # cv2.waitKey(1)
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    if self.num_sample_color:
         | 
| 382 | 
            +
                        color_data = self.get_color_sampling(subject, yid=yid, pid=pid)
         | 
| 383 | 
            +
                        res.update(color_data)
         | 
| 384 | 
            +
                    return res
         | 
| 385 | 
            +
                    # except Exception as e:
         | 
| 386 | 
            +
                    #     print(e)
         | 
| 387 | 
            +
                    #     return self.get_item(index=random.randint(0, self.__len__() - 1))
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                def __getitem__(self, index):
         | 
| 390 | 
            +
                    return self.get_item(index)
         | 
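To make the calibration pipeline in get_render concrete, here is a minimal sketch (not part of the commit) of how the extrinsic and intrinsic matrices compose to map a world-space point into the [-1, 1] uv space that the network later samples with index()/grid_sample. All numeric values below are made up, and the augmentation translation (trans_intrinsic) is left as identity:

import numpy as np

load_size = 512
ortho_ratio = 0.4                       # hypothetical pixel/world ratio
scale = 1.0                             # hypothetical world/model scale
center = np.array([0.0, 100.0, 0.0])    # hypothetical camera target
R = np.identity(3)                      # hypothetical model rotation

# extrinsic: world -> camera, built exactly as in get_render
translate = -np.matmul(R, center).reshape(3, 1)
extrinsic = np.concatenate([np.concatenate([R, translate], axis=1),
                            np.array([[0, 0, 0, 1]])], 0)

# intrinsics: camera -> pixel (scale_intrinsic) -> normalized uv (uv_intrinsic)
scale_intrinsic = np.diag([scale / ortho_ratio, -scale / ortho_ratio,
                           scale / ortho_ratio, 1.0])
uv_intrinsic = np.diag([1.0 / (load_size // 2)] * 3 + [1.0])

calib = uv_intrinsic @ scale_intrinsic @ extrinsic

pt = np.array([0.0, 180.0, 0.0, 1.0])   # a homogeneous world-space point
print(calib @ pt)                        # x and y land in [-1, 1] if the point is visible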
    	
        PIFu/lib/data/__init__.py
    ADDED
    
@@ -0,0 +1,2 @@
+from .EvalDataset import EvalDataset
+from .TrainDataset import TrainDataset
    	
        PIFu/lib/ext_transform.py
    ADDED
    
@@ -0,0 +1,78 @@
+import random
+
+import numpy as np
+from skimage.filters import gaussian
+import torch
+from PIL import Image, ImageFilter
+
+
+class RandomVerticalFlip(object):
+    def __call__(self, img):
+        if random.random() < 0.5:
+            return img.transpose(Image.FLIP_TOP_BOTTOM)
+        return img
+
+
+class DeNormalize(object):
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):
+        for t, m, s in zip(tensor, self.mean, self.std):
+            t.mul_(s).add_(m)
+        return tensor
+
+
+class MaskToTensor(object):
+    def __call__(self, img):
+        return torch.from_numpy(np.array(img, dtype=np.int32)).long()
+
+
+class FreeScale(object):
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = tuple(reversed(size))  # size: (h, w)
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        return img.resize(self.size, self.interpolation)
+
+
+class FlipChannels(object):
+    def __call__(self, img):
+        img = np.array(img)[:, :, ::-1]
+        return Image.fromarray(img.astype(np.uint8))
+
+
+class RandomGaussianBlur(object):
+    def __call__(self, img):
+        sigma = 0.15 + random.random() * 1.15
+        blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
+        blurred_img *= 255
+        return Image.fromarray(blurred_img.astype(np.uint8))
+
+# Lighting data augmentation taken from here - https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py
+
+
+class Lighting(object):
+    """Lighting noise (AlexNet-style PCA-based noise)"""
+
+    def __init__(self, alphastd,
+                 eigval=(0.2175, 0.0188, 0.0045),
+                 eigvec=((-0.5675, 0.7192, 0.4009),
+                         (-0.5808, -0.0045, -0.8140),
+                         (-0.5836, -0.6948, 0.4203))):
+        self.alphastd = alphastd
+        self.eigval = torch.Tensor(eigval)
+        self.eigvec = torch.Tensor(eigvec)
+
+    def __call__(self, img):
+        if self.alphastd == 0:
+            return img
+
+        alpha = img.new().resize_(3).normal_(0, self.alphastd)
+        rgb = self.eigvec.type_as(img).clone()\
+            .mul(alpha.view(1, 3).expand(3, 3))\
+            .mul(self.eigval.view(1, 3).expand(3, 3))\
+            .sum(1).squeeze()
+        return img.add(rgb.view(3, 1, 1).expand_as(img))
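A hedged usage sketch for these transforms (not part of the commit): Lighting expects a CHW float tensor, so it must come after ToTensor, while RandomVerticalFlip operates on a PIL image; the alphastd value and the import path (PIFu/ on sys.path) are assumptions:

import torch
from PIL import Image
from torchvision import transforms
from lib.ext_transform import Lighting, RandomVerticalFlip

pipeline = transforms.Compose([
    RandomVerticalFlip(),            # PIL image in, PIL image out
    transforms.ToTensor(),           # PIL -> [3, H, W] float tensor in [0, 1]
    Lighting(alphastd=0.1),          # PCA color noise applied to the tensor
])

# smoke test on a dummy image
img = Image.new('RGB', (64, 64), color=(128, 64, 32))
out = pipeline(img)
print(out.shape)  # torch.Size([3, 64, 64])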
    	
        PIFu/lib/geometry.py
    ADDED
    
@@ -0,0 +1,55 @@
+import torch
+
+
+def index(feat, uv):
+    '''
+
+    :param feat: [B, C, H, W] image features
+    :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
+    :return: [B, C, N] image features at the uv coordinates
+    '''
+    uv = uv.transpose(1, 2)  # [B, N, 2]
+    uv = uv.unsqueeze(2)  # [B, N, 1, 2]
+    # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
+    # for old versions, simply remove the align_corners argument.
+    samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)  # [B, C, N, 1]
+    return samples[:, :, :, 0]  # [B, C, N]
+
+
+def orthogonal(points, calibrations, transforms=None):
+    '''
+    Compute the orthogonal projections of 3D points onto the image plane by the given projection matrix
+    :param points: [B, 3, N] Tensor of 3D points
+    :param calibrations: [B, 4, 4] Tensor of projection matrices
+    :param transforms: [B, 2, 3] Tensor of image transform matrices
+    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
+    '''
+    rot = calibrations[:, :3, :3]
+    trans = calibrations[:, :3, 3:4]
+    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
+    if transforms is not None:
+        scale = transforms[:2, :2]
+        shift = transforms[:2, 2:3]
+        pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
+    return pts
+
+
+def perspective(points, calibrations, transforms=None):
+    '''
+    Compute the perspective projections of 3D points onto the image plane by the given projection matrix
+    :param points: [Bx3xN] Tensor of 3D points
+    :param calibrations: [Bx4x4] Tensor of projection matrices
+    :param transforms: [Bx2x3] Tensor of image transform matrices
+    :return: xy: [Bx2xN] Tensor of xy coordinates in the image plane
+    '''
+    rot = calibrations[:, :3, :3]
+    trans = calibrations[:, :3, 3:4]
+    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
+    xy = homo[:, :2, :] / homo[:, 2:3, :]
+    if transforms is not None:
+        scale = transforms[:2, :2]
+        shift = transforms[:2, 2:3]
+        xy = torch.baddbmm(shift, scale, xy)
+
+    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
+    return xyz
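These two helpers are the core of the pixel-aligned query: project a 3D point into the image plane, then bilinearly sample the feature map at that location. A minimal sketch (not part of the commit) wiring them together on random tensors, assuming the PIFu directory is on sys.path:

import torch
from lib.geometry import index, orthogonal

B, C, H, W, N = 2, 8, 32, 32, 100
feat = torch.randn(B, C, H, W)                 # image features from some filter
points = torch.randn(B, 3, N)                  # world-space query points
calibs = torch.eye(4).expand(B, 4, 4)          # identity calibration, for the sketch

xyz = orthogonal(points, calibs)               # [B, 3, N] image-plane coordinates
xy = xyz[:, :2, :]                             # grid_sample expects xy in [-1, 1]
point_feat = index(feat, xy)                   # [B, C, N] pixel-aligned features
print(point_feat.shape)                        # torch.Size([2, 8, 100])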
    	
        PIFu/lib/mesh_util.py
    ADDED
    
@@ -0,0 +1,91 @@
+from skimage import measure
+import numpy as np
+import torch
+from .sdf import create_grid, eval_grid_octree, eval_grid
+
+
+def reconstruction(net, cuda, calib_tensor,
+                   resolution, b_min, b_max,
+                   use_octree=False, num_samples=10000, transform=None):
+    '''
+    Reconstruct meshes from the sdf predicted by the network.
+    :param net: a BasePIFuNet object. Call the image filter beforehand.
+    :param cuda: cuda device
+    :param calib_tensor: calibration tensor
+    :param resolution: resolution of the grid cell
+    :param b_min: bounding box corner [x_min, y_min, z_min]
+    :param b_max: bounding box corner [x_max, y_max, z_max]
+    :param use_octree: whether to use octree acceleration
+    :param num_samples: how many points to query per gpu iteration
+    :return: marching cubes results.
+    '''
+    # First we create a grid by resolution
+    # and transforming matrix for grid coordinates to real world xyz
+    coords, mat = create_grid(resolution, resolution, resolution,
+                              b_min, b_max, transform=transform)
+
+    # Then we define the lambda function for cell evaluation
+    def eval_func(points):
+        points = np.expand_dims(points, axis=0)
+        points = np.repeat(points, net.num_views, axis=0)
+        samples = torch.from_numpy(points).to(device=cuda).float()
+        net.query(samples, calib_tensor)
+        pred = net.get_preds()[0][0]
+        return pred.detach().cpu().numpy()
+
+    # Then we evaluate the grid
+    if use_octree:
+        sdf = eval_grid_octree(coords, eval_func, num_samples=num_samples)
+    else:
+        sdf = eval_grid(coords, eval_func, num_samples=num_samples)
+
+    # Finally we do marching cubes
+    try:
+        verts, faces, normals, values = measure.marching_cubes_lewiner(sdf, 0.5)
+        # transform verts into world coordinate system
+        verts = np.matmul(mat[:3, :3], verts.T) + mat[:3, 3:4]
+        verts = verts.T
+        return verts, faces, normals, values
+    except:
+        print('error: cannot perform marching cubes')
+        return -1
+
+
+def save_obj_mesh(mesh_path, verts, faces):
+    file = open(mesh_path, 'w')
+
+    for v in verts:
+        file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
+    for f in faces:
+        f_plus = f + 1
+        file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
+    file.close()
+
+
+def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
+    file = open(mesh_path, 'w')
+
+    for idx, v in enumerate(verts):
+        c = colors[idx]
+        file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2]))
+    for f in faces:
+        f_plus = f + 1
+        file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
+    file.close()
+
+
+def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):
+    file = open(mesh_path, 'w')
+
+    for idx, v in enumerate(verts):
+        vt = uvs[idx]
+        file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
+        file.write('vt %.4f %.4f\n' % (vt[0], vt[1]))
+
+    for f in faces:
+        f_plus = f + 1
+        file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0],
+                                              f_plus[2], f_plus[2],
+                                              f_plus[1], f_plus[1]))
+    file.close()
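A quick sanity-check sketch (not part of the commit) for the OBJ writers, using a single triangle. Note that the face loop writes indices 1-based and flips the winding order (v0, v2, v1); the import path is an assumption:

import numpy as np
from lib.mesh_util import save_obj_mesh

verts = np.array([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])       # 0-based indices; the writer adds 1

save_obj_mesh('triangle.obj', verts, faces)
# triangle.obj now contains three 'v' lines and one face line: 'f 1 3 2'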
    	
        PIFu/lib/model/BasePIFuNet.py
    ADDED
    
@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..geometry import index, orthogonal, perspective
+
+class BasePIFuNet(nn.Module):
+    def __init__(self,
+                 projection_mode='orthogonal',
+                 error_term=nn.MSELoss(),
+                 ):
+        """
+        :param projection_mode:
+        Either orthogonal or perspective.
+        It will call the corresponding function for projection.
+        :param error_term:
+        nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
+        """
+        super(BasePIFuNet, self).__init__()
+        self.name = 'base'
+
+        self.error_term = error_term
+
+        self.index = index
+        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
+
+        self.preds = None
+        self.labels = None
+
+    def forward(self, points, images, calibs, transforms=None):
+        '''
+        :param points: [B, 3, N] world space coordinates of points
+        :param images: [B, C, H, W] input images
+        :param calibs: [B, 3, 4] calibration matrices for each image
+        :param transforms: Optional [B, 2, 3] image space coordinate transforms
+        :return: [B, Res, N] predictions for each point
+        '''
+        self.filter(images)
+        self.query(points, calibs, transforms)
+        return self.get_preds()
+
+    def filter(self, images):
+        '''
+        Filter the input images
+        store all intermediate features.
+        :param images: [B, C, H, W] input images
+        '''
+        None
+
+    def query(self, points, calibs, transforms=None, labels=None):
+        '''
+        Given 3D points, query the network predictions for each point.
+        Image features should be pre-computed before this call.
+        store all intermediate features.
+        query() function may behave differently during training/testing.
+        :param points: [B, 3, N] world space coordinates of points
+        :param calibs: [B, 3, 4] calibration matrices for each image
+        :param transforms: Optional [B, 2, 3] image space coordinate transforms
+        :param labels: Optional [B, Res, N] gt labeling
+        :return: [B, Res, N] predictions for each point
+        '''
+        None
+
+    def get_preds(self):
+        '''
+        Get the predictions from the last query
+        :return: [B, Res, N] network prediction for the last query
+        '''
+        return self.preds
+
+    def get_error(self):
+        '''
+        Get the network loss from the last query
+        :return: loss term
+        '''
+        return self.error_term(self.preds, self.labels)
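To show the contract this base class defines (filter once, then query many points), here is a hedged toy subclass, not part of the commit; the feature extractor and "occupancy" head below are trivial stand-ins, where the repo's real subclasses such as HGPIFuNet use a stacked hourglass filter and an MLP classifier:

import torch
import torch.nn as nn
from lib.model.BasePIFuNet import BasePIFuNet  # assuming PIFu/ is on sys.path

class ToyPIFuNet(BasePIFuNet):
    def __init__(self):
        super(ToyPIFuNet, self).__init__(projection_mode='orthogonal')
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.feat = None

    def filter(self, images):
        # compute and store image features for later point queries
        self.feat = self.conv(images)

    def query(self, points, calibs, transforms=None, labels=None):
        if labels is not None:
            self.labels = labels
        xyz = self.projection(points, calibs, transforms)   # [B, 3, N]
        point_feat = self.index(self.feat, xyz[:, :2, :])   # [B, 8, N]
        # trivial "occupancy": mean feature through a sigmoid
        self.preds = torch.sigmoid(point_feat.mean(dim=1, keepdim=True))

net = ToyPIFuNet()
images = torch.randn(2, 3, 64, 64)
points = torch.randn(2, 3, 50)
calibs = torch.eye(4).expand(2, 4, 4)
print(net(points, images, calibs).shape)  # torch.Size([2, 1, 50])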
    	
        PIFu/lib/model/ConvFilters.py
    ADDED
    
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.resnet as resnet
import torchvision.models.vgg as vgg


class MultiConv(nn.Module):
    def __init__(self, filter_channels):
        super(MultiConv, self).__init__()
        self.filters = []

        for l in range(0, len(filter_channels) - 1):
            self.filters.append(
                nn.Conv2d(filter_channels[l], filter_channels[l + 1], kernel_size=4, stride=2))
            self.add_module("conv%d" % l, self.filters[l])

    def forward(self, image):
        '''
        :param image: [BxC_inxHxW] tensor of input image
        :return: list of [BxC_outxHxW] tensors of output features
        '''
        y = image
        # y = F.relu(self.bn0(self.conv0(y)), True)
        feat_pyramid = [y]
        for i, f in enumerate(self.filters):
            y = f(y)
            if i != len(self.filters) - 1:
                y = F.leaky_relu(y)
            # y = F.max_pool2d(y, kernel_size=2, stride=2)
            feat_pyramid.append(y)
        return feat_pyramid


class Vgg16(torch.nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = vgg.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()

        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h

        return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]


class ResNet(nn.Module):
    def __init__(self, model='resnet18'):
        super(ResNet, self).__init__()

        if model == 'resnet18':
            net = resnet.resnet18(pretrained=True)
        elif model == 'resnet34':
            net = resnet.resnet34(pretrained=True)
        elif model == 'resnet50':
            net = resnet.resnet50(pretrained=True)
        else:
            raise NameError('Unknown ResNet filter setting!')

        self.conv1 = net.conv1

        self.pool = net.maxpool
        self.layer0 = nn.Sequential(net.conv1, net.bn1, net.relu)
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4

    def forward(self, image):
        '''
        :param image: [BxC_inxHxW] tensor of input image
        :return: list of [BxC_outxHxW] tensors of output features
        '''

        y = image
        feat_pyramid = []
        y = self.layer0(y)
        feat_pyramid.append(y)
        y = self.layer1(self.pool(y))
        feat_pyramid.append(y)
        y = self.layer2(y)
        feat_pyramid.append(y)
        y = self.layer3(y)
        feat_pyramid.append(y)
        y = self.layer4(y)
        feat_pyramid.append(y)

        return feat_pyramid
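
Usage sketch (editor's note, not part of this commit): each filter above returns a list of feature maps. A minimal way to inspect the pyramid shapes, assuming the script is run from the PIFu/ directory so that lib is importable; the MultiConv channel schedule is an illustrative assumption, not a value from this repo.

import torch
from lib.model.ConvFilters import MultiConv, ResNet

image = torch.randn(1, 3, 512, 512)          # [B, C, H, W] dummy input

# Hypothetical channel schedule; the real one comes from opt.enc_dim.
multiconv = MultiConv([3, 64, 128, 256])
for i, feat in enumerate(multiconv(image)):
    print('multiconv level %d:' % i, tuple(feat.shape))

# Downloads ImageNet weights on first use.
backbone = ResNet(model='resnet18')
for i, feat in enumerate(backbone(image)):
    print('resnet level %d:' % i, tuple(feat.shape))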
    	
        PIFu/lib/model/ConvPIFuNet.py
    ADDED
    
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from .ConvFilters import *
from ..net_util import init_net

class ConvPIFuNet(BasePIFuNet):
    '''
    Conv Piximp network is the standard three-phase network used here.
    The image filter is a plain multi-layer convolutional network;
    during the feature extraction phase, all features in the pyramid
    at the projected location are aggregated.
    It does the following:
        1. Compute image feature pyramids and store them in self.im_feat_list
        2. Apply calibration, index into each feature map, and concatenate the results
        3. Classification.
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        super(ConvPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'convpifu'

        self.opt = opt
        self.num_views = self.opt.num_views

        self.image_filter = self.define_imagefilter(opt)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Sigmoid())

        self.normalizer = DepthNormalizer(opt)

        # This is a list of [B x Feat_i x H x W] features
        self.im_feat_list = []

        init_net(self)

    def define_imagefilter(self, opt):
        net = None
        if opt.netIMF == 'multiconv':
            net = MultiConv(opt.enc_dim)
        elif 'resnet' in opt.netIMF:
            net = ResNet(model=opt.netIMF)
        elif opt.netIMF == 'vgg16':
            net = Vgg16()
        else:
            raise NotImplementedError('model name [%s] is not recognized' % opt.netIMF)

        return net

    def filter(self, images):
        '''
        Filter the input images and
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        self.im_feat_list = self.image_filter(images)

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        query() may behave differently during training and testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: [B, Res, N] predictions for each point
        '''
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        z_feat = self.normalizer(z)

        # This is a list of [B, Feat_i, N] features
        point_local_feat_list = [self.index(im_feat, xy) for im_feat in self.im_feat_list]
        point_local_feat_list.append(z_feat)
        # [B, Feat_all, N]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        self.preds = self.surface_classifier(point_local_feat)
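
Editor's sketch (not part of this commit) of the three-phase call pattern the docstring describes. The option values are illustrative guesses: with a resnet18 filter, the pyramid carries 64+64+128+256+512 = 1024 channels plus one depth channel, so the first MLP width is assumed to be 1025; run from the PIFu/ directory.

import argparse
import torch
from lib.model.ConvPIFuNet import ConvPIFuNet

# Illustrative options only; the real ones come from lib/options.py.
opt = argparse.Namespace(
    num_views=1, netIMF='resnet18',
    mlp_dim=[1025, 1024, 512, 256, 128, 1],  # 1024 pyramid + 1 depth channel
    no_residual=False, loadSize=512, z_size=200.0)

net = ConvPIFuNet(opt).eval()
images = torch.randn(1, 3, 512, 512)         # [B, C, H, W]
points = torch.rand(1, 3, 5000) * 2 - 1      # [B, 3, N] query points
calibs = torch.eye(4)[:3].unsqueeze(0)       # [B, 3, 4] identity calibration

with torch.no_grad():
    net.filter(images)                        # phase 1: image feature pyramid
    net.query(points, calibs)                 # phases 2-3: project, index, classify
    print(net.get_preds().shape)              # torch.Size([1, 1, 5000])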
    	
        PIFu/lib/model/DepthNormalizer.py
    ADDED
    
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthNormalizer(nn.Module):
    def __init__(self, opt):
        super(DepthNormalizer, self).__init__()
        self.opt = opt

    def forward(self, z, calibs=None, index_feat=None):
        '''
        Normalize the depth values.
        :param z: [B, 1, N] depth value for z in the image coordinate system
        :return: [B, 1, N] normalized z feature
        '''
        z_feat = z * (self.opt.loadSize // 2) / self.opt.z_size
        return z_feat
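
The normalization above is a single affine scaling of the depth coordinate. A worked example (editor's note, not from the commit), using what appear to be the repo defaults loadSize=512 and z_size=200:

import torch
from lib.model.DepthNormalizer import DepthNormalizer

class Opt:                      # minimal stand-in for the full option parser
    loadSize = 512              # assumed default image size
    z_size = 200.0              # assumed default z normalization span

z = torch.tensor([[[0.5]]])     # [B, 1, N] depth values
# 0.5 * (512 // 2) / 200 = 0.64
print(DepthNormalizer(Opt())(z))   # tensor([[[0.6400]]])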
    	
        PIFu/lib/model/HGFilters.py
    ADDED
    
@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..net_util import *


class HourGlass(nn.Module):
    def __init__(self, num_modules, depth, num_features, norm='batch'):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.norm = norm

        self._generate_network(self.depth)

    def _generate_network(self, level):
        self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

        self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

        self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

    def _forward(self, level, inp):
        # Upper branch
        up1 = inp
        up1 = self._modules['b1_' + str(level)](up1)

        # Lower branch
        low1 = F.avg_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = low1
            low2 = self._modules['b2_plus_' + str(level)](low2)

        low3 = low2
        low3 = self._modules['b3_' + str(level)](low3)

        # NOTE: for newer PyTorch (1.3~), training results seem to degrade due to an
        # implementation difference in F.grid_sample; if the pretrained model behaves
        # weirdly, switch to the commented line below.
        # NOTE: "bicubic" was also found to work better than "nearest".
        up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True)
        # up2 = F.interpolate(low3, scale_factor=2, mode='nearest')

        return up1 + up2

    def forward(self, x):
        return self._forward(self.depth, x)


class HGFilter(nn.Module):
    def __init__(self, opt):
        super(HGFilter, self).__init__()
        self.num_modules = opt.num_stack

        self.opt = opt

        # Base part
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)

        if self.opt.norm == 'batch':
            self.bn1 = nn.BatchNorm2d(64)
        elif self.opt.norm == 'group':
            self.bn1 = nn.GroupNorm(32, 64)

        if self.opt.hg_down == 'conv64':
            self.conv2 = ConvBlock(64, 64, self.opt.norm)
            self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        elif self.opt.hg_down == 'conv128':
            self.conv2 = ConvBlock(64, 128, self.opt.norm)
            self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        elif self.opt.hg_down == 'ave_pool':
            self.conv2 = ConvBlock(64, 128, self.opt.norm)
        else:
            raise NameError('Unknown Fan Filter setting!')

        self.conv3 = ConvBlock(128, 128, self.opt.norm)
        self.conv4 = ConvBlock(128, 256, self.opt.norm)

        # Stacking part
        for hg_module in range(self.num_modules):
            self.add_module('m' + str(hg_module), HourGlass(1, opt.num_hourglass, 256, self.opt.norm))

            self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.opt.norm))
            self.add_module('conv_last' + str(hg_module),
                            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            if self.opt.norm == 'batch':
                self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
            elif self.opt.norm == 'group':
                self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256))

            self.add_module('l' + str(hg_module), nn.Conv2d(256,
                                                            opt.hourglass_dim, kernel_size=1, stride=1, padding=0))

            if hg_module < self.num_modules - 1:
                self.add_module(
                    'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module('al' + str(hg_module), nn.Conv2d(opt.hourglass_dim,
                                                                 256, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)), True)
        tmpx = x
        if self.opt.hg_down == 'ave_pool':
            x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        elif self.opt.hg_down in ['conv64', 'conv128']:
            x = self.conv2(x)
            x = self.down_conv2(x)
        else:
            raise NameError('Unknown Fan Filter setting!')

        normx = x

        x = self.conv3(x)
        x = self.conv4(x)

        previous = x

        outputs = []
        for i in range(self.num_modules):
            hg = self._modules['m' + str(i)](previous)

            ll = hg
            ll = self._modules['top_m_' + str(i)](ll)

            ll = F.relu(self._modules['bn_end' + str(i)]
                        (self._modules['conv_last' + str(i)](ll)), True)

            # Predict heatmaps
            tmp_out = self._modules['l' + str(i)](ll)
            outputs.append(tmp_out)

            if i < self.num_modules - 1:
                ll = self._modules['bl' + str(i)](ll)
                tmp_out_ = self._modules['al' + str(i)](tmp_out)
                previous = previous + ll + tmp_out_

        return outputs, tmpx.detach(), normx
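
Editor's sketch (not part of this commit) of the stacked-hourglass filter's outputs. The option values mirror what the training scripts appear to use by default and are assumptions here; each stack yields a hourglass_dim-channel map at a quarter of the input resolution.

import torch
from lib.model.HGFilters import HGFilter

class Opt:                      # assumed defaults, not from this commit
    num_stack = 4               # number of hourglass stacks
    num_hourglass = 2           # recursion depth inside each hourglass
    norm = 'group'
    hg_down = 'ave_pool'
    hourglass_dim = 256         # channels of each stack's output

outputs, tmpx, normx = HGFilter(Opt())(torch.randn(1, 3, 512, 512))
print(len(outputs), tuple(outputs[-1].shape))   # 4 (1, 256, 128, 128)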
    	
        PIFu/lib/model/HGPIFuNet.py
    ADDED
    
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from .HGFilters import *
from ..net_util import init_net


class HGPIFuNet(BasePIFuNet):
    '''
    HG PIFu network uses stacked hourglass modules as the image filter.
    It does the following:
        1. Compute image feature stacks and store them in self.im_feat_list
            self.im_feat_list[-1] is the last stack (output stack)
        2. Apply calibration
        3. During training, index into every intermediate stack;
            during testing, index into the last stack only.
        4. Classification.
        5. During training, the error is computed on all stacks.
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        super(HGPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'hgpifu'

        self.opt = opt
        self.num_views = self.opt.num_views

        self.image_filter = HGFilter(opt)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Sigmoid())

        self.normalizer = DepthNormalizer(opt)

        # This is a list of [B x Feat_i x H x W] features
        self.im_feat_list = []
        self.tmpx = None
        self.normx = None

        self.intermediate_preds_list = []

        init_net(self)

    def filter(self, images):
        '''
        Filter the input images and
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        self.im_feat_list, self.tmpx, self.normx = self.image_filter(images)
        # If not in training mode, only keep the last stack's features
        if not self.training:
            self.im_feat_list = [self.im_feat_list[-1]]

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        query() may behave differently during training and testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: [B, Res, N] predictions for each point
        '''
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        in_img = (xy[:, 0] >= -1.0) & (xy[:, 0] <= 1.0) & (xy[:, 1] >= -1.0) & (xy[:, 1] <= 1.0)

        z_feat = self.normalizer(z, calibs=calibs)

        if self.opt.skip_hourglass:
            tmpx_local_feature = self.index(self.tmpx, xy)

        self.intermediate_preds_list = []

        for im_feat in self.im_feat_list:
            # [B, Feat_i + z, N]
            point_local_feat_list = [self.index(im_feat, xy), z_feat]

            if self.opt.skip_hourglass:
                point_local_feat_list.append(tmpx_local_feature)

            point_local_feat = torch.cat(point_local_feat_list, 1)

            # points outside the image plane are always set to 0
            pred = in_img[:, None].float() * self.surface_classifier(point_local_feat)
            self.intermediate_preds_list.append(pred)

        self.preds = self.intermediate_preds_list[-1]

    def get_im_feat(self):
        '''
        Get the filtered image feature
        :return: [B, C_feat, H, W] image feature after filtering
        '''
        return self.im_feat_list[-1]

    def get_error(self):
        '''
        Hourglass has its own intermediate supervision scheme
        '''
        error = 0
        for preds in self.intermediate_preds_list:
            error += self.error_term(preds, self.labels)
        error /= len(self.intermediate_preds_list)

        return error

    def forward(self, images, points, calibs, transforms=None, labels=None):
        # Phase 1: get image features
        self.filter(images)

        # Phase 2: point query
        self.query(points=points, calibs=calibs, transforms=transforms, labels=labels)

        # Get the prediction
        res = self.get_preds()

        # Get the error
        error = self.get_error()

        return res, error
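
Editor's sketch (not part of this commit): one supervised forward pass through HGPIFuNet, exercising the intermediate supervision in get_error(). All option values are assumed defaults; in train mode the loss is averaged over all num_stack intermediate predictions.

import torch
from lib.model.HGPIFuNet import HGPIFuNet

class Opt:                      # assumed defaults, not from this commit
    num_views = 1
    num_stack = 4
    num_hourglass = 2
    norm = 'group'
    hg_down = 'ave_pool'
    hourglass_dim = 256
    mlp_dim = [257, 1024, 512, 256, 128, 1]   # 256 image + 1 depth channel
    no_residual = False
    skip_hourglass = False
    loadSize = 512
    z_size = 200.0

net = HGPIFuNet(Opt()).train()  # train mode keeps all intermediate stacks
images = torch.randn(1, 3, 512, 512)
points = torch.rand(1, 3, 5000) * 2 - 1        # keep projections in-image
calibs = torch.eye(4)[:3].unsqueeze(0)
labels = torch.randint(0, 2, (1, 1, 5000)).float()  # inside/outside gt

preds, error = net(images, points, calibs, labels=labels)
print(preds.shape, float(error))                # torch.Size([1, 1, 5000])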
    	
        PIFu/lib/model/ResBlkPIFuNet.py
    ADDED
    
    | @@ -0,0 +1,201 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from .BasePIFuNet import BasePIFuNet
         | 
| 5 | 
            +
            import functools
         | 
| 6 | 
            +
            from .SurfaceClassifier import SurfaceClassifier
         | 
| 7 | 
            +
            from .DepthNormalizer import DepthNormalizer
         | 
| 8 | 
            +
            from ..net_util import *
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class ResBlkPIFuNet(BasePIFuNet):
         | 
| 12 | 
            +
                def __init__(self, opt,
         | 
| 13 | 
            +
                             projection_mode='orthogonal'):
         | 
| 14 | 
            +
                    if opt.color_loss_type == 'l1':
         | 
| 15 | 
            +
                        error_term = nn.L1Loss()
         | 
| 16 | 
            +
                    elif opt.color_loss_type == 'mse':
         | 
| 17 | 
            +
                        error_term = nn.MSELoss()
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                    super(ResBlkPIFuNet, self).__init__(
         | 
| 20 | 
            +
                        projection_mode=projection_mode,
         | 
| 21 | 
            +
                        error_term=error_term)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    self.name = 'respifu'
         | 
| 24 | 
            +
                    self.opt = opt
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    norm_type = get_norm_layer(norm_type=opt.norm_color)
         | 
| 27 | 
            +
                    self.image_filter = ResnetFilter(opt, norm_layer=norm_type)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    self.surface_classifier = SurfaceClassifier(
         | 
| 30 | 
            +
                        filter_channels=self.opt.mlp_dim_color,
         | 
| 31 | 
            +
                        num_views=self.opt.num_views,
         | 
| 32 | 
            +
                        no_residual=self.opt.no_residual,
         | 
| 33 | 
            +
                        last_op=nn.Tanh())
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    self.normalizer = DepthNormalizer(opt)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    init_net(self)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def filter(self, images):
         | 
| 40 | 
            +
                    '''
         | 
| 41 | 
            +
                    Filter the input images
         | 
| 42 | 
            +
                    store all intermediate features.
         | 
| 43 | 
            +
                    :param images: [B, C, H, W] input images
         | 
| 44 | 
            +
                    '''
         | 
| 45 | 
            +
                    self.im_feat = self.image_filter(images)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def attach(self, im_feat):
         | 
| 48 | 
            +
                    self.im_feat = torch.cat([im_feat, self.im_feat], 1)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                def query(self, points, calibs, transforms=None, labels=None):
         | 
| 51 | 
            +
                    '''
         | 
| 52 | 
            +
                    Given 3D points, query the network predictions for each point.
         | 
| 53 | 
            +
                    Image features should be pre-computed before this call.
         | 
| 54 | 
            +
                    store all intermediate features.
         | 
| 55 | 
            +
                    query() function may behave differently during training/testing.
         | 
| 56 | 
            +
                    :param points: [B, 3, N] world space coordinates of points
         | 
| 57 | 
            +
                    :param calibs: [B, 3, 4] calibration matrices for each image
         | 
| 58 | 
            +
                    :param transforms: Optional [B, 2, 3] image space coordinate transforms
         | 
| 59 | 
            +
                    :param labels: Optional [B, Res, N] gt labeling
         | 
| 60 | 
            +
                    :return: [B, Res, N] predictions for each point
         | 
| 61 | 
            +
                    '''
         | 
| 62 | 
            +
                    if labels is not None:
         | 
| 63 | 
            +
                        self.labels = labels
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    xyz = self.projection(points, calibs, transforms)
         | 
| 66 | 
            +
                    xy = xyz[:, :2, :]
         | 
| 67 | 
            +
                    z = xyz[:, 2:3, :]
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    z_feat = self.normalizer(z)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    # This is a list of [B, Feat_i, N] features
         | 
| 72 | 
            +
                    point_local_feat_list = [self.index(self.im_feat, xy), z_feat]
         | 
| 73 | 
            +
                    # [B, Feat_all, N]
         | 
| 74 | 
            +
                    point_local_feat = torch.cat(point_local_feat_list, 1)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    self.preds = self.surface_classifier(point_local_feat)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def forward(self, images, im_feat, points, calibs, transforms=None, labels=None):
         | 
| 79 | 
            +
                    self.filter(images)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    self.attach(im_feat)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    self.query(points, calibs, transforms, labels)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    res = self.get_preds()
         | 
| 86 | 
            +
                    error = self.get_error()
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    return res, error
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            class ResnetBlock(nn.Module):
         | 
| 91 | 
            +
                """Define a Resnet block"""
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
         | 
| 94 | 
            +
                    """Initialize the Resnet block
         | 
| 95 | 
            +
                    A resnet block is a conv block with skip connections
         | 
| 96 | 
            +
                    We construct a conv block with build_conv_block function,
         | 
| 97 | 
            +
                    and implement skip connections in <forward> function.
         | 
| 98 | 
            +
                    Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
         | 
| 99 | 
            +
                    """
         | 
| 100 | 
            +
                    super(ResnetBlock, self).__init__()
         | 
| 101 | 
            +
                    self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
         | 
| 104 | 
            +
                    """Construct a convolutional block.
         | 
| 105 | 
            +
                    Parameters:
         | 
| 106 | 
            +
                        dim (int)           -- the number of channels in the conv layer.
         | 
| 107 | 
            +
                        padding_type (str)  -- the name of padding layer: reflect | replicate | zero
         | 
| 108 | 
            +
                        norm_layer          -- normalization layer
         | 
| 109 | 
            +
                        use_dropout (bool)  -- if use dropout layers.
         | 
| 110 | 
            +
                        use_bias (bool)     -- if the conv layer uses bias or not
         | 
| 111 | 
            +
                    Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
         | 
| 112 | 
            +
                    """
         | 
| 113 | 
            +
                    conv_block = []
         | 
| 114 | 
            +
                    p = 0
         | 
| 115 | 
            +
                    if padding_type == 'reflect':
         | 
| 116 | 
            +
                        conv_block += [nn.ReflectionPad2d(1)]
         | 
| 117 | 
            +
                    elif padding_type == 'replicate':
         | 
| 118 | 
            +
                        conv_block += [nn.ReplicationPad2d(1)]
         | 
| 119 | 
            +
                    elif padding_type == 'zero':
         | 
| 120 | 
            +
                        p = 1
         | 
| 121 | 
            +
                    else:
         | 
| 122 | 
            +
                        raise NotImplementedError('padding [%s] is not implemented' % padding_type)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
         | 
| 125 | 
            +
                    if use_dropout:
         | 
| 126 | 
            +
                        conv_block += [nn.Dropout(0.5)]
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    p = 0
         | 
| 129 | 
            +
                    if padding_type == 'reflect':
         | 
| 130 | 
            +
                        conv_block += [nn.ReflectionPad2d(1)]
         | 
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        if last:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        else:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connection)"""
        out = x + self.conv_block(x)  # add the skip connection
        return out


class ResnetFilter(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    We adapt Torch code and the idea from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator
        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the first conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of the padding layer in conv layers: reflect | replicate | zero
        """
        assert (n_blocks >= 0)
        super(ResnetFilter, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            if i == n_blocks - 1:
                model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias, last=True)]
            else:
                model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias)]

        if opt.use_tanh:
            model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
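A quick smoke test of ResnetFilter (illustrative only: the opt object below is a minimal stand-in for the parsed options, and use_tanh is the only field the constructor reads; note that output_nc is accepted but never used, the output width is ngf * 4):

    import types
    import torch

    opt = types.SimpleNamespace(use_tanh=False)   # hypothetical minimal options object
    net = ResnetFilter(opt, input_nc=3, ngf=64, n_blocks=6)
    x = torch.randn(1, 3, 512, 512)               # one RGB image at the default load size
    feat = net(x)
    print(feat.shape)                             # torch.Size([1, 256, 128, 128]) after two stride-2 downsamplings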
    	
        PIFu/lib/model/SurfaceClassifier.py
    ADDED
    
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SurfaceClassifier(nn.Module):
    def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None):
        super(SurfaceClassifier, self).__init__()

        self.filters = []
        self.num_views = num_views
        self.no_residual = no_residual
        self.last_op = last_op

        if self.no_residual:
            for l in range(0, len(filter_channels) - 1):
                self.filters.append(nn.Conv1d(
                    filter_channels[l],
                    filter_channels[l + 1],
                    1))
                self.add_module("conv%d" % l, self.filters[l])
        else:
            for l in range(0, len(filter_channels) - 1):
                if 0 != l:
                    self.filters.append(
                        nn.Conv1d(
                            filter_channels[l] + filter_channels[0],
                            filter_channels[l + 1],
                            1))
                else:
                    self.filters.append(nn.Conv1d(
                        filter_channels[l],
                        filter_channels[l + 1],
                        1))

                self.add_module("conv%d" % l, self.filters[l])

    def forward(self, feature):
        '''
        :param feature: [B, C_in, N] tensor of per-point image features
        :return: [B, C_out, N] tensor of predictions for each point
        '''
        y = feature
        tmpy = feature
        for i, f in enumerate(self.filters):
            if self.no_residual:
                y = self._modules['conv' + str(i)](y)
            else:
                y = self._modules['conv' + str(i)](
                    y if i == 0
                    else torch.cat([y, tmpy], 1)
                )
            if i != len(self.filters) - 1:
                y = F.leaky_relu(y)

            # For multi-view inputs, average the per-view features halfway
            # through the MLP so later layers see a fused representation.
            if self.num_views > 1 and i == len(self.filters) // 2:
                y = y.view(
                    -1, self.num_views, y.shape[1], y.shape[2]
                ).mean(dim=1)
                tmpy = feature.view(
                    -1, self.num_views, feature.shape[1], feature.shape[2]
                ).mean(dim=1)

        if self.last_op:
            y = self.last_op(y)

        return y
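A minimal usage sketch for SurfaceClassifier (channel widths below are illustrative, not the project defaults):

    import torch
    import torch.nn as nn

    mlp = SurfaceClassifier(filter_channels=[257, 512, 256, 128, 1],
                            num_views=1, no_residual=False, last_op=nn.Sigmoid())
    point_feat = torch.randn(2, 257, 5000)    # [B, C_in, N] per-point image features
    occ = mlp(point_feat)                     # [2, 1, 5000] in/out probabilities in (0, 1)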
    	
        PIFu/lib/model/VhullPIFuNet.py
    ADDED
    
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet


class VhullPIFuNet(BasePIFuNet):
    '''
    Vhull Piximp network is a minimal network demonstrating how the template works;
    it also helps debug the training/test schemes.
    It does the following:
        1. Computes the masks of the images and stores them under self.im_feat
        2. Calculates calibration and indexing
        3. Returns whether the points fall into the intersection of all masks
    '''

    def __init__(self,
                 num_views,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        super(VhullPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)
        self.name = 'vhull'

        self.num_views = num_views

        self.im_feat = None

    def filter(self, images):
        '''
        Filter the input images
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        # If the image has an alpha channel, use the alpha channel as the mask
        if images.shape[1] > 3:
            self.im_feat = images[:, 3:4, :, :]
        # Otherwise, fall back to the first channel as a proxy mask
        else:
            self.im_feat = images[:, 0:1, :, :]

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        store all intermediate features.
        query() function may behave differently during training/testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: [B, Res, N] predictions for each point
        '''
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]

        point_local_feat = self.index(self.im_feat, xy)
        local_shape = point_local_feat.shape
        point_feat = point_local_feat.view(
            local_shape[0] // self.num_views,
            local_shape[1] * self.num_views,
            -1)
        # The product over views is high only where a point projects inside
        # every view's mask, i.e. inside the visual hull.
        pred = torch.prod(point_feat, dim=1)

        self.preds = pred.unsqueeze(1)
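A quick sanity check of the visual-hull network (the identity calibration below is an illustrative setup, not a real camera):

    import torch

    net = VhullPIFuNet(num_views=1)
    images = torch.rand(1, 4, 512, 512)           # RGBA; the alpha channel acts as the mask
    calibs = torch.eye(4)[:3].unsqueeze(0)        # [1, 3, 4] identity orthogonal calibration
    points = torch.rand(1, 3, 100) * 2 - 1        # query points in [-1, 1]^3
    net.filter(images)
    net.query(points, calibs)
    print(net.get_preds().shape)                  # torch.Size([1, 1, 100])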
    	
        PIFu/lib/model/__init__.py
    ADDED
    
@@ -0,0 +1,5 @@
from .BasePIFuNet import BasePIFuNet
from .VhullPIFuNet import VhullPIFuNet
from .ConvPIFuNet import ConvPIFuNet
from .HGPIFuNet import HGPIFuNet
from .ResBlkPIFuNet import ResBlkPIFuNet
    	
        PIFu/lib/net_util.py
    ADDED
    
@@ -0,0 +1,396 @@
import torch
from torch.nn import init
import torch.nn as nn
import torch.nn.functional as F
import functools

import numpy as np
from .mesh_util import *
from .sample_util import *
from .geometry import index
import cv2
from PIL import Image
from tqdm import tqdm


def reshape_multiview_tensors(image_tensor, calib_tensor):
    # Careful here! Because single-view and multi-view data go through the same
    # path, the incoming tensor is 5-dim: [B, num_views, C, W, H]. Convert it
    # back to 4-dim: [B * num_views, C, W, H]. The classifier handles the
    # multi-view case downstream.
    image_tensor = image_tensor.view(
        image_tensor.shape[0] * image_tensor.shape[1],
        image_tensor.shape[2],
        image_tensor.shape[3],
        image_tensor.shape[4]
    )
    calib_tensor = calib_tensor.view(
        calib_tensor.shape[0] * calib_tensor.shape[1],
        calib_tensor.shape[2],
        calib_tensor.shape[3]
    )

    return image_tensor, calib_tensor


def reshape_sample_tensor(sample_tensor, num_views):
    if num_views == 1:
        return sample_tensor
    # Repeat sample_tensor along the batch dim num_views times
    sample_tensor = sample_tensor.unsqueeze(dim=1)
    sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
    sample_tensor = sample_tensor.view(
        sample_tensor.shape[0] * sample_tensor.shape[1],
        sample_tensor.shape[2],
        sample_tensor.shape[3]
    )
    return sample_tensor
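Shape check for the two reshape helpers (values illustrative):

    import torch

    imgs = torch.randn(2, 3, 3, 512, 512)    # [B, num_views, C, H, W]
    calibs = torch.randn(2, 3, 4, 4)
    imgs, calibs = reshape_multiview_tensors(imgs, calibs)
    print(imgs.shape, calibs.shape)           # [6, 3, 512, 512] and [6, 4, 4]

    samples = torch.randn(2, 3, 5000)         # [B, 3, N] query points
    print(reshape_sample_tensor(samples, 3).shape)   # [6, 3, 5000]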

def gen_mesh(opt, net, cuda, data, save_path, use_octree=True):
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    net.filter(image_tensor)

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        save_img_path = save_path[:-4] + '.png'
        save_img_list = []
        for v in range(image_tensor.shape[0]):
            save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
            save_img_list.append(save_img)
        save_img = np.concatenate(save_img_list, axis=1)
        Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)

        verts, faces, _, _ = reconstruction(
            net, cuda, calib_tensor, opt.resolution, b_min, b_max, use_octree=use_octree)
        verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()
        xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
        uv = xyz_tensor[:, :2, :]
        color = index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
        color = color * 0.5 + 0.5
        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)
        print('Cannot create marching cubes at this time.')


def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    netG.filter(image_tensor)
    netC.filter(image_tensor)
    netC.attach(netG.get_im_feat())

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        save_img_path = save_path[:-4] + '.png'
        save_img_list = []
        for v in range(image_tensor.shape[0]):
            save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
            save_img_list.append(save_img)
        save_img = np.concatenate(save_img_list, axis=1)
        Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)

        verts, faces, _, _ = reconstruction(
            netG, cuda, calib_tensor, opt.resolution, b_min, b_max, use_octree=use_octree)

        # Now get the vertex colors, querying the color net in chunks
        verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()
        verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)

        color = np.zeros(verts.shape)
        interval = opt.num_sample_color
        for i in range(len(color) // interval):
            left = i * interval
            right = i * interval + interval
            if i == len(color) // interval - 1:
                right = len(color)  # fold the remainder vertices into the last chunk
            netC.query(verts_tensor[:, :, left:right], calib_tensor)
            rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
            color[left:right] = rgb.T

        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)
        print('Cannot create marching cubes at this time.')
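gen_mesh and gen_mesh_color are driven by dict-style samples from the evaluation datasets; a hedged sketch of the minimal fields they read (keys taken from the code above, shapes illustrative):

    import numpy as np
    import torch

    data = {
        'img':   torch.randn(1, 3, 512, 512),    # [num_views, C, H, W], values in [-1, 1]
        'calib': torch.eye(4).unsqueeze(0),       # [num_views, 4, 4] calibration matrices
        'b_min': np.array([-1.0, -1.0, -1.0]),    # bounds of the reconstruction volume
        'b_max': np.array([1.0, 1.0, 1.0]),
    }
    # gen_mesh(opt, net, cuda, data, 'result.obj')   # net: a trained PIFu network, cuda: a torch.device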
| 121 | 
            +
            def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
         | 
| 122 | 
            +
                """Sets the learning rate to the initial LR decayed by schedule"""
         | 
| 123 | 
            +
                if epoch in schedule:
         | 
| 124 | 
            +
                    lr *= gamma
         | 
| 125 | 
            +
                    for param_group in optimizer.param_groups:
         | 
| 126 | 
            +
                        param_group['lr'] = lr
         | 
| 127 | 
            +
                return lr
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            def compute_acc(pred, gt, thresh=0.5):
         | 
| 131 | 
            +
                '''
         | 
| 132 | 
            +
                return:
         | 
| 133 | 
            +
                    IOU, precision, and recall
         | 
| 134 | 
            +
                '''
         | 
| 135 | 
            +
                with torch.no_grad():
         | 
| 136 | 
            +
                    vol_pred = pred > thresh
         | 
| 137 | 
            +
                    vol_gt = gt > thresh
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    union = vol_pred | vol_gt
         | 
| 140 | 
            +
                    inter = vol_pred & vol_gt
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    true_pos = inter.sum().float()
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    union = union.sum().float()
         | 
| 145 | 
            +
                    if union == 0:
         | 
| 146 | 
            +
                        union = 1
         | 
| 147 | 
            +
                    vol_pred = vol_pred.sum().float()
         | 
| 148 | 
            +
                    if vol_pred == 0:
         | 
| 149 | 
            +
                        vol_pred = 1
         | 
| 150 | 
            +
                    vol_gt = vol_gt.sum().float()
         | 
| 151 | 
            +
                    if vol_gt == 0:
         | 
| 152 | 
            +
                        vol_gt = 1
         | 
| 153 | 
            +
                    return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
         | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
| 156 | 
            +
            def calc_error(opt, net, cuda, dataset, num_tests):
         | 
| 157 | 
            +
                if num_tests > len(dataset):
         | 
| 158 | 
            +
                    num_tests = len(dataset)
         | 
| 159 | 
            +
                with torch.no_grad():
         | 
| 160 | 
            +
                    erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
         | 
| 161 | 
            +
                    for idx in tqdm(range(num_tests)):
         | 
| 162 | 
            +
                        data = dataset[idx * len(dataset) // num_tests]
         | 
| 163 | 
            +
                        # retrieve the data
         | 
| 164 | 
            +
                        image_tensor = data['img'].to(device=cuda)
         | 
| 165 | 
            +
                        calib_tensor = data['calib'].to(device=cuda)
         | 
| 166 | 
            +
                        sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
         | 
| 167 | 
            +
                        if opt.num_views > 1:
         | 
| 168 | 
            +
                            sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)
         | 
| 169 | 
            +
                        label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                        res, error = net.forward(image_tensor, sample_tensor, calib_tensor, labels=label_tensor)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        IOU, prec, recall = compute_acc(res, label_tensor)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                        # print(
         | 
| 176 | 
            +
                        #     '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
         | 
| 177 | 
            +
                        #         .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
         | 
| 178 | 
            +
                        erorr_arr.append(error.item())
         | 
| 179 | 
            +
                        IOU_arr.append(IOU.item())
         | 
| 180 | 
            +
                        prec_arr.append(prec.item())
         | 
| 181 | 
            +
                        recall_arr.append(recall.item())
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                return np.average(erorr_arr), np.average(IOU_arr), np.average(prec_arr), np.average(recall_arr)
         | 
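compute_acc on a hand-checkable example:

    import torch

    pred = torch.tensor([0.9, 0.8, 0.2, 0.1])
    gt   = torch.tensor([1.0, 0.0, 1.0, 0.0])
    iou, prec, recall = compute_acc(pred, gt)
    # intersection = 1, union = 3, |pred| = 2, |gt| = 2
    print(iou.item(), prec.item(), recall.item())    # 0.3333..., 0.5, 0.5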

def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
    if num_tests > len(dataset):
        num_tests = len(dataset)
    with torch.no_grad():
        error_color_arr = []

        for idx in tqdm(range(num_tests)):
            data = dataset[idx * len(dataset) // num_tests]
            # retrieve the data
            image_tensor = data['img'].to(device=cuda)
            calib_tensor = data['calib'].to(device=cuda)
            color_sample_tensor = data['color_samples'].to(device=cuda).unsqueeze(0)

            if opt.num_views > 1:
                color_sample_tensor = reshape_sample_tensor(color_sample_tensor, opt.num_views)

            rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)

            netG.filter(image_tensor)
            _, errorC = netC.forward(image_tensor, netG.get_im_feat(), color_sample_tensor, calib_tensor, labels=rgb_tensor)

            # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
            #       .format(idx, num_tests, errorG.item(), errorC.item()))
            error_color_arr.append(errorC.item())

    return np.average(error_color_arr)

def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=strd, padding=padding, bias=bias)


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper, but xavier and kaiming might
    work better for some applications. Feel free to try them yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm's weight is not a matrix; only the normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPU
    init_weights(net, init_type, init_gain=init_gain)
    return net
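Example: initializing a small module on CPU (an empty gpu_ids list skips the DataParallel wrap):

    import torch.nn as nn

    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(True))
    net = init_net(net, init_type='kaiming', gpu_ids=[])   # prints 'initialize network with kaiming'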

def imageSpaceRotation(xy, rot):
    '''
    args:
        xy: (B, 2, N) input
        rot: (B, 2) x,y axis rotation angles

    The rotation center is always the image center (any other rotation center
    can be represented by an additional z translation).
    '''
    disp = rot.unsqueeze(2).sin().expand_as(xy)
    return (disp * xy).sum(dim=1)


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in the WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- whether we mix real and fake data or not [real | fake | mixed]
        constant (float)         -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':  # either use real images, fake images, or a linear interpolation of the two
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(
                *real_data.shape)
            alpha = alpha.to(device)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # eps added for numerical stability
        return gradient_penalty, gradients
    else:
        return 0.0, None
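A self-contained check of cal_gradient_penalty (the tiny discriminator is illustrative):

    import torch
    import torch.nn as nn

    netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
    real = torch.randn(4, 3, 8, 8)
    fake = torch.randn(4, 3, 8, 8)
    gp, grads = cal_gradient_penalty(netD, real, fake, device='cpu', type='mixed')
    gp.backward()    # penalty gradients flow back into netD's parameters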

def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'group':
        norm_layer = functools.partial(nn.GroupNorm, 32)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
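Note that the 'group' branch fixes the number of groups at 32, so any channel count passed to the returned constructor must be divisible by 32:

    norm_layer = get_norm_layer('group')
    gn = norm_layer(256)    # equivalent to nn.GroupNorm(32, 256)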

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes, norm='batch'):
        super(ConvBlock, self).__init__()
        # Three cascaded 3x3 convolutions whose outputs are concatenated,
        # so out_planes must be divisible by 4 (out/2 + out/4 + out/4).
        self.conv1 = conv3x3(in_planes, int(out_planes / 2))
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))

        if norm == 'batch':
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
            self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
            self.bn4 = nn.BatchNorm2d(in_planes)
        elif norm == 'group':
            self.bn1 = nn.GroupNorm(32, in_planes)
            self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
            self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
            self.bn4 = nn.GroupNorm(32, in_planes)

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                self.bn4,
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=1, bias=False),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x

        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out3 += residual

        return out3
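Shape sanity check for ConvBlock (channel counts must be divisible by 4, and by 32 when using GroupNorm):

    import torch

    block = ConvBlock(64, 128, norm='group')
    x = torch.randn(1, 64, 32, 32)
    print(block(x).shape)    # torch.Size([1, 128, 32, 32]); spatial size is preserved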
    	
        PIFu/lib/options.py
    ADDED
    
@@ -0,0 +1,157 @@
import argparse
import os


class BaseOptions():
    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        # Datasets related
        g_data = parser.add_argument_group('Data')
        g_data.add_argument('--dataroot', type=str, default='./data',
                            help='path to images (data folder)')

        g_data.add_argument('--loadSize', type=int, default=512, help='load size of input image')

        # Experiment related
        g_exp = parser.add_argument_group('Experiment')
        g_exp.add_argument('--name', type=str, default='example',
                           help='name of the experiment. It decides where to store samples and models')
        g_exp.add_argument('--debug', action='store_true', help='debug mode or not')

        g_exp.add_argument('--num_views', type=int, default=1, help='how many views to use for the multiview network')
        g_exp.add_argument('--random_multiview', action='store_true', help='select a random multiview combination')

        # Training related
        g_train = parser.add_argument_group('Training')
        g_train.add_argument('--gpu_id', type=int, default=0, help='gpu id for cuda')
        g_train.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. "0", "0,1,2", "0,2"; use -1 for CPU mode')

        g_train.add_argument('--num_threads', default=1, type=int, help='# threads for loading data')
        g_train.add_argument('--serial_batches', action='store_true',
                             help='if true, takes images in order to make batches, otherwise takes them randomly')
        g_train.add_argument('--pin_memory', action='store_true', help='pin_memory')

        g_train.add_argument('--batch_size', type=int, default=2, help='input batch size')
        g_train.add_argument('--learning_rate', type=float, default=1e-3, help='adam learning rate')
        g_train.add_argument('--learning_rateC', type=float, default=1e-3, help='adam learning rate for the color network')
        g_train.add_argument('--num_epoch', type=int, default=100, help='num epoch to train')

        g_train.add_argument('--freq_plot', type=int, default=10, help='frequency of the error plot')
        g_train.add_argument('--freq_save', type=int, default=50, help='frequency of saving checkpoints')
        g_train.add_argument('--freq_save_ply', type=int, default=100, help='frequency of saving ply')

        g_train.add_argument('--no_gen_mesh', action='store_true')
        g_train.add_argument('--no_num_eval', action='store_true')

        g_train.add_argument('--resume_epoch', type=int, default=-1, help='epoch resuming the training')
        g_train.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')

        # Testing related
        g_test = parser.add_argument_group('Testing')
        g_test.add_argument('--resolution', type=int, default=256, help='# of grid cells in mesh reconstruction')
        g_test.add_argument('--test_folder_path', type=str, default=None, help='the folder of test images')

        # Sampling related
        g_sample = parser.add_argument_group('Sampling')
        g_sample.add_argument('--sigma', type=float, default=5.0, help='perturbation standard deviation for positions')

        g_sample.add_argument('--num_sample_inout', type=int, default=5000, help='# of sampling points for occupancy')
        g_sample.add_argument('--num_sample_color', type=int, default=0, help='# of sampling points for color')

        g_sample.add_argument('--z_size', type=float, default=200.0, help='z normalization factor')

        # Model related
        g_model = parser.add_argument_group('Model')
        # General
        g_model.add_argument('--norm', type=str, default='group',
                             help='instance, batch or group normalization')
        g_model.add_argument('--norm_color', type=str, default='instance',
                             help='instance, batch or group normalization')

        # hg filter specify
        g_model.add_argument('--num_stack', type=int, default=4, help='# of hourglass stacks')
        g_model.add_argument('--num_hourglass', type=int, default=2, help='# of stacked layers per hourglass')
        g_model.add_argument('--skip_hourglass', action='store_true', help='skip connection in hourglass')
        g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave_pool | conv64 | conv128')
        g_model.add_argument('--hourglass_dim', type=int, default=256, help='256 | 512')

        # Classification General
        g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int,
                             help='# of dimensions of mlp')
        g_model.add_argument('--mlp_dim_color', nargs='+', default=[513, 1024, 512, 256, 128, 3],
                             type=int, help='# of dimensions of color mlp')

        g_model.add_argument('--use_tanh', action='store_true',
                             help='use tanh after the last conv of the image_filter network')

        # for train
        parser.add_argument('--random_flip', action='store_true', help='if random flip')
        parser.add_argument('--random_trans', action='store_true', help='if random translation')
        parser.add_argument('--random_scale', action='store_true', help='if random scale')
        parser.add_argument('--no_residual', action='store_true', help='no skip connection in mlp')
        parser.add_argument('--schedule', type=int, nargs='+', default=[60, 80],
                            help='Decrease learning rate at these epochs.')
        parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
        parser.add_argument('--color_loss_type', type=str, default='l1', help='mse | l1')

        # for eval
        parser.add_argument('--val_test_error', action='store_true', help='validate errors of test data')
        parser.add_argument('--val_train_error', action='store_true', help='validate errors of train data')
        parser.add_argument('--gen_test_mesh', action='store_true', help='generate test mesh')
        parser.add_argument('--gen_train_mesh', action='store_true', help='generate train mesh')
        parser.add_argument('--all_mesh', action='store_true', help='generate meshes from all hourglass outputs')
        parser.add_argument('--num_gen_mesh_test', type=int, default=1,
                            help='how many meshes to generate during testing')

        # path
        parser.add_argument('--checkpoints_path', type=str, default='./checkpoints', help='path to save checkpoints')
        parser.add_argument('--load_netG_checkpoint_path', type=str, default=None, help='path to a netG checkpoint to load')
        parser.add_argument('--load_netC_checkpoint_path', type=str, default=None, help='path to a netC checkpoint to load')
        parser.add_argument('--results_path', type=str, default='./results', help='path to save results ply')
        parser.add_argument('--load_checkpoint_path', type=str, help='path to a checkpoint to load')
        parser.add_argument('--single', type=str, default='', help='single data for training')
        # for single image reconstruction
        parser.add_argument('--mask_path', type=str, help='path for input mask')
        parser.add_argument('--img_path', type=str, help='path for input image')

        # aug
        group_aug = parser.add_argument_group('aug')
        group_aug.add_argument('--aug_alstd', type=float, default=0.0, help='augmentation pca lighting alpha std')
        group_aug.add_argument('--aug_bri', type=float, default=0.0, help='augmentation brightness')
        group_aug.add_argument('--aug_con', type=float, default=0.0, help='augmentation contrast')
        group_aug.add_argument('--aug_sat', type=float, default=0.0, help='augmentation saturation')
        group_aug.add_argument('--aug_hue', type=float, default=0.0, help='augmentation hue')
        group_aug.add_argument('--aug_blur', type=float, default=0.0, help='augmentation blur')

        # special tasks
        self.initialized = True
        return parser

    def gather_options(self):
        # build the parser with the basic options only once, then reuse it
        if not self.initialized:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            self.parser = self.initialize(parser)

        return self.parser.parse_args()

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

    def parse(self):
        opt = self.gather_options()
        return opt
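A minimal driver sketch for these options (hypothetical script, not part of the commit; the flags shown are illustrative):

# parse command-line flags, e.g.:  python script.py --name my_exp --batch_size 4
options = BaseOptions()
opt = options.parse()
options.print_options(opt)   # echoes every option, marking non-default values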
    	
PIFu/lib/renderer/__init__.py ADDED (file without changes)
    	
PIFu/lib/renderer/camera.py ADDED
import cv2
import numpy as np

from .glm import ortho


class Camera:
    def __init__(self, width=1600, height=1200):
        # Focal Length
        # equivalent 50mm
        focal = np.sqrt(width * width + height * height)
        self.focal_x = focal
        self.focal_y = focal
        # Principal Point Offset
        self.principal_x = width / 2
        self.principal_y = height / 2
        # Axis Skew
        self.skew = 0
        # Image Size
        self.width = width
        self.height = height

        self.near = 1
        self.far = 10

        # Camera Center
        self.center = np.array([0, 0, 1.6])
        self.direction = np.array([0, 0, -1])
        self.right = np.array([1, 0, 0])
        self.up = np.array([0, 1, 0])

        self.ortho_ratio = None

    def sanity_check(self):
        self.center = self.center.reshape([-1])
        self.direction = self.direction.reshape([-1])
        self.right = self.right.reshape([-1])
        self.up = self.up.reshape([-1])

        assert len(self.center) == 3
        assert len(self.direction) == 3
        assert len(self.right) == 3
        assert len(self.up) == 3

    @staticmethod
    def normalize_vector(v):
        v_norm = np.linalg.norm(v)
        return v if v_norm == 0 else v / v_norm

    def get_real_z_value(self, z):
        z_near = self.near
        z_far = self.far
        z_n = 2.0 * z - 1.0
        z_e = 2.0 * z_near * z_far / (z_far + z_near - z_n * (z_far - z_near))
        return z_e

    def get_rotation_matrix(self):
        rot_mat = np.eye(3)
        s = self.right
        s = self.normalize_vector(s)
        rot_mat[0, :] = s
        u = self.up
        u = self.normalize_vector(u)
        rot_mat[1, :] = -u
        rot_mat[2, :] = self.normalize_vector(self.direction)

        return rot_mat

    def get_translation_vector(self):
        rot_mat = self.get_rotation_matrix()
        trans = -np.dot(rot_mat, self.center)
        return trans

    def get_intrinsic_matrix(self):
        int_mat = np.eye(3)

        int_mat[0, 0] = self.focal_x
        int_mat[1, 1] = self.focal_y
        int_mat[0, 1] = self.skew
        int_mat[0, 2] = self.principal_x
        int_mat[1, 2] = self.principal_y

        return int_mat

    def get_projection_matrix(self):
        ext_mat = self.get_extrinsic_matrix()
        int_mat = self.get_intrinsic_matrix()

        return np.matmul(int_mat, ext_mat)

    def get_extrinsic_matrix(self):
        rot_mat = self.get_rotation_matrix()
        trans = self.get_translation_vector()

        extrinsic = np.eye(4)
        extrinsic[:3, :3] = rot_mat
        extrinsic[:3, 3] = trans

        return extrinsic[:3, :]

    def set_rotation_matrix(self, rot_mat):
        self.direction = rot_mat[2, :]
        self.up = -rot_mat[1, :]
        self.right = rot_mat[0, :]

    def set_intrinsic_matrix(self, int_mat):
        self.focal_x = int_mat[0, 0]
        self.focal_y = int_mat[1, 1]
        self.skew = int_mat[0, 1]
        self.principal_x = int_mat[0, 2]
        self.principal_y = int_mat[1, 2]

    def set_projection_matrix(self, proj_mat):
        res = cv2.decomposeProjectionMatrix(proj_mat)
        int_mat, rot_mat, camera_center_homo = res[0], res[1], res[2]
        camera_center = camera_center_homo[0:3] / camera_center_homo[3]
        camera_center = camera_center.reshape(-1)
        int_mat = int_mat / int_mat[2][2]

        self.set_intrinsic_matrix(int_mat)
        self.set_rotation_matrix(rot_mat)
        self.center = camera_center

        self.sanity_check()

    def get_gl_matrix(self):
        z_near = self.near
        z_far = self.far
        rot_mat = self.get_rotation_matrix()
        int_mat = self.get_intrinsic_matrix()
        trans = self.get_translation_vector()

        extrinsic = np.eye(4)
        extrinsic[:3, :3] = rot_mat
        extrinsic[:3, 3] = trans
        axis_adj = np.eye(4)
        axis_adj[2, 2] = -1
        axis_adj[1, 1] = -1
        model_view = np.matmul(axis_adj, extrinsic)

        projective = np.zeros([4, 4])
        projective[:2, :2] = int_mat[:2, :2]
        projective[:2, 2:3] = -int_mat[:2, 2:3]
        projective[3, 2] = -1
        projective[2, 2] = (z_near + z_far)
        projective[2, 3] = (z_near * z_far)

        if self.ortho_ratio is None:
            ndc = ortho(0, self.width, 0, self.height, z_near, z_far)
            perspective = np.matmul(ndc, projective)
        else:
            perspective = ortho(-self.width * self.ortho_ratio / 2, self.width * self.ortho_ratio / 2,
                                -self.height * self.ortho_ratio / 2, self.height * self.ortho_ratio / 2,
                                z_near, z_far)

        return perspective, model_view


def KRT_from_P(proj_mat, normalize_K=True):
    res = cv2.decomposeProjectionMatrix(proj_mat)
    K, Rot, camera_center_homog = res[0], res[1], res[2]
    camera_center = camera_center_homog[0:3] / camera_center_homog[3]
    trans = -Rot.dot(camera_center)
    if normalize_K:
        K = K / K[2][2]
    return K, Rot, trans


def MVP_from_P(proj_mat, width, height, near=0.1, far=10000):
    '''
    Convert an OpenCV camera calibration matrix to OpenGL projection and model-view matrices.
    :param proj_mat: OpenCV camera projection matrix
    :param width: Image width
    :param height: Image height
    :param near: Z near value
    :param far: Z far value
    :return: OpenGL projection matrix and model-view matrix
    '''
    res = cv2.decomposeProjectionMatrix(proj_mat)
    K, Rot, camera_center_homog = res[0], res[1], res[2]
    camera_center = camera_center_homog[0:3] / camera_center_homog[3]
    trans = -Rot.dot(camera_center)
    K = K / K[2][2]

    extrinsic = np.eye(4)
    extrinsic[:3, :3] = Rot
    extrinsic[:3, 3:4] = trans
    axis_adj = np.eye(4)
    axis_adj[2, 2] = -1
    axis_adj[1, 1] = -1
    model_view = np.matmul(axis_adj, extrinsic)

    zFar = far
    zNear = near
    projective = np.zeros([4, 4])
    projective[:2, :2] = K[:2, :2]
    projective[:2, 2:3] = -K[:2, 2:3]
    projective[3, 2] = -1
    projective[2, 2] = (zNear + zFar)
    projective[2, 3] = (zNear * zFar)

    ndc = ortho(0, width, 0, height, zNear, zFar)

    perspective = np.matmul(ndc, projective)

    return perspective, model_view
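As a quick sanity check of the conversions above, the sketch below (hypothetical driver code, not in the commit; the import path assumes this repo layout) round-trips a Camera through its 3x4 projection matrix:

import numpy as np
from lib.renderer.camera import Camera, KRT_from_P

cam = Camera(width=1600, height=1200)
P = cam.get_projection_matrix()      # 3x4 matrix K [R | t]

# cv2.decomposeProjectionMatrix should recover the intrinsics up to numerical noise
K, R, t = KRT_from_P(P)
assert np.allclose(K, cam.get_intrinsic_matrix())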
    	
PIFu/lib/renderer/gl/__init__.py ADDED (file without changes)
    	
PIFu/lib/renderer/gl/cam_render.py ADDED
from .render import Render

GLUT = None


class CamRender(Render):
    def __init__(self, width=1600, height=1200, name='Cam Renderer',
                 program_files=['simple.fs', 'simple.vs'], color_size=1, ms_rate=1, egl=False):
        Render.__init__(self, width, height, name, program_files, color_size, ms_rate=ms_rate, egl=egl)
        self.camera = None

        if not egl:
            global GLUT
            import OpenGL.GLUT as GLUT
            GLUT.glutDisplayFunc(self.display)
            GLUT.glutKeyboardFunc(self.keyboard)

    def set_camera(self, camera):
        self.camera = camera
        self.projection_matrix, self.model_view_matrix = camera.get_gl_matrix()

    def keyboard(self, key, x, y):
        # move the camera center along its local axes; 'i'/'o' shift the clip planes
        eps = 1
        if key == b'w':
            self.camera.center += eps * self.camera.direction
        elif key == b's':
            self.camera.center -= eps * self.camera.direction
        if key == b'a':
            self.camera.center -= eps * self.camera.right
        elif key == b'd':
            self.camera.center += eps * self.camera.right
        if key == b' ':
            self.camera.center += eps * self.camera.up
        elif key == b'x':
            self.camera.center -= eps * self.camera.up
        elif key == b'i':
            self.camera.near += 0.1 * eps
            self.camera.far += 0.1 * eps
        elif key == b'o':
            self.camera.near -= 0.1 * eps
            self.camera.far -= 0.1 * eps

        self.projection_matrix, self.model_view_matrix = self.camera.get_gl_matrix()

    def show(self):
        if GLUT is not None:
            GLUT.glutMainLoop()
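The keyboard handler above is plain vector arithmetic on the Camera, so the motion can be exercised without any GL context (hypothetical snippet, reusing camera.py from this commit):

from lib.renderer.camera import Camera

cam = Camera()
eps = 1
# what pressing 'w' does: dolly the eye point along the view direction
cam.center = cam.center + eps * cam.direction
# the renderer then refreshes both GL matrices
projection, model_view = cam.get_gl_matrix()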
    	
PIFu/lib/renderer/gl/data/prt.fs ADDED
#version 330

uniform vec3 SHCoeffs[9];
uniform uint analytic;

uniform uint hasNormalMap;
uniform uint hasAlbedoMap;

uniform sampler2D AlbedoMap;
uniform sampler2D NormalMap;

in VertexData {
    vec3 Position;
    vec3 Depth;
    vec3 ModelNormal;
    vec2 Texcoord;
    vec3 Tangent;
    vec3 Bitangent;
    vec3 PRT1;
    vec3 PRT2;
    vec3 PRT3;
} VertexIn;

layout (location = 0) out vec4 FragColor;
layout (location = 1) out vec4 FragNormal;
layout (location = 2) out vec4 FragPosition;
layout (location = 3) out vec4 FragAlbedo;
layout (location = 4) out vec4 FragShading;
layout (location = 5) out vec4 FragPRT1;
layout (location = 6) out vec4 FragPRT2;
layout (location = 7) out vec4 FragPRT3;

vec4 gammaCorrection(vec4 vec, float g)
{
    return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
}

vec3 gammaCorrection(vec3 vec, float g)
{
    return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
}

void evaluateH(vec3 n, out float H[9])
{
    float c1 = 0.429043, c2 = 0.511664,
        c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;

    H[0] = c4;
    H[1] = 2.0 * c2 * n[1];
    H[2] = 2.0 * c2 * n[2];
    H[3] = 2.0 * c2 * n[0];
    H[4] = 2.0 * c1 * n[0] * n[1];
    H[5] = 2.0 * c1 * n[1] * n[2];
    H[6] = c3 * n[2] * n[2] - c5;
    H[7] = 2.0 * c1 * n[2] * n[0];
    H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
}

vec3 evaluateLightingModel(vec3 normal)
{
    float H[9];
    evaluateH(normal, H);
    vec3 res = vec3(0.0);
    for (int i = 0; i < 9; i++) {
        res += H[i] * SHCoeffs[i];
    }
    return res;
}

// nC: coarse geometry normal, nH: fine normal from normal map
vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
{
    float HC[9], HH[9];
    evaluateH(nC, HC);
    evaluateH(nH, HH);

    vec3 res = vec3(0.0);
    vec3 shadow = vec3(0.0);
    vec3 unshadow = vec3(0.0);
    for(int i = 0; i < 3; ++i){
        for(int j = 0; j < 3; ++j){
            int id = i*3+j;
            res += HH[id] * SHCoeffs[id];
            shadow += prt[i][j] * SHCoeffs[id];
            unshadow += HC[id] * SHCoeffs[id];
        }
    }
    vec3 ratio = clamp(shadow/unshadow, 0.0, 1.0);
    res = ratio * res;

    return res;
}

vec3 evaluateLightingModelPRT(mat3 prt)
{
    vec3 res = vec3(0.0);
    for(int i = 0; i < 3; ++i){
        for(int j = 0; j < 3; ++j){
            res += prt[i][j] * SHCoeffs[i*3+j];
        }
    }

    return res;
}

void main()
{
    vec2 uv = VertexIn.Texcoord;
    vec3 nC = normalize(VertexIn.ModelNormal);
    vec3 nml = nC;
    mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);

    if(hasAlbedoMap == uint(0))
        FragAlbedo = vec4(1.0);
    else
        FragAlbedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);

    if(hasNormalMap == uint(0))
    {
        if(analytic == uint(0))
            FragShading = vec4(evaluateLightingModelPRT(prt), 1.0f);
        else
            FragShading = vec4(evaluateLightingModel(nC), 1.0f);
    }
    else
    {
        vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0));

        mat3 TBN = mat3(normalize(VertexIn.Tangent), normalize(VertexIn.Bitangent), nC);
        vec3 nH = normalize(TBN * n_tan);

        if(analytic == uint(0))
            FragShading = vec4(evaluateLightingModelHybrid(nC, nH, prt), 1.0f);
        else
            FragShading = vec4(evaluateLightingModel(nH), 1.0f);

        nml = nH;
    }

    FragShading = gammaCorrection(FragShading, 2.2);
    FragColor = clamp(FragAlbedo * FragShading, 0.0, 1.0);
    FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0);
    FragPosition = vec4(VertexIn.Position, VertexIn.Depth.x);
    FragShading = vec4(clamp(0.5*FragShading.xyz, 0.0, 1.0), 1.0);
    // FragColor = gammaCorrection(clamp(FragAlbedo * FragShading, 0.0, 1.0),2.2);
    // FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0);
    // FragPosition = vec4(VertexIn.Position,VertexIn.Depth.x);
    // FragShading = vec4(gammaCorrection(clamp(0.5*FragShading.xyz, 0.0, 1.0),2.2),1.0);
    // FragAlbedo = gammaCorrection(FragAlbedo,2.2);
    FragPRT1 = vec4(VertexIn.PRT1, 1.0);
    FragPRT2 = vec4(VertexIn.PRT2, 1.0);
    FragPRT3 = vec4(VertexIn.PRT3, 1.0);
}
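evaluateH above is the standard quadratic spherical-harmonics irradiance basis of Ramamoorthi and Hanrahan. For reference, a CPU-side Python mirror of the shader's analytic shading path (function names here are illustrative, not part of the repo):

import numpy as np

def evaluate_h(n):
    # same constants and basis ordering as evaluateH in prt.fs
    c1, c2, c3, c4, c5 = 0.429043, 0.511664, 0.743125, 0.886227, 0.247708
    x, y, z = n
    return np.array([c4,
                     2.0 * c2 * y, 2.0 * c2 * z, 2.0 * c2 * x,
                     2.0 * c1 * x * y, 2.0 * c1 * y * z,
                     c3 * z * z - c5,
                     2.0 * c1 * z * x,
                     c1 * (x * x - y * y)])

def shade(normal, sh_coeffs):
    # sh_coeffs: (9, 3) array of RGB SH coefficients, as in uniform SHCoeffs[9]
    return evaluate_h(normal) @ sh_coeffs   # -> RGB radiance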
    	
PIFu/lib/renderer/gl/data/prt.vs ADDED
#version 330

layout (location = 0) in vec3 a_Position;
layout (location = 1) in vec3 a_Normal;
layout (location = 2) in vec2 a_TextureCoord;
layout (location = 3) in vec3 a_Tangent;
layout (location = 4) in vec3 a_Bitangent;
layout (location = 5) in vec3 a_PRT1;
layout (location = 6) in vec3 a_PRT2;
layout (location = 7) in vec3 a_PRT3;

out VertexData {
    vec3 Position;
    vec3 Depth;
    vec3 ModelNormal;
    vec2 Texcoord;
    vec3 Tangent;
    vec3 Bitangent;
    vec3 PRT1;
    vec3 PRT2;
    vec3 PRT3;
} VertexOut;

uniform mat3 RotMat;
uniform mat4 NormMat;
uniform mat4 ModelMat;
uniform mat4 PerspMat;

float s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi))
float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi))
float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi))

float s_c_scale = 1.0/0.91529123286551084;
float s_c_scale_inv = 0.91529123286551084;

float s_rc2 = 1.5853309190550713*s_c_scale;
float s_c4_div_c3 = s_c4/s_c3;
float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0;

float s_scale_dst2 = s_c3 * s_c_scale_inv;
float s_scale_dst4 = s_c5 * s_c_scale_inv;

void OptRotateBand0(float x[1], mat3 R, out float dst[1])
{
    dst[0] = x[0];
}

// 9 multiplies
void OptRotateBand1(float x[3], mat3 R, out float dst[3])
{
    // derived from SlowRotateBand1
    dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2];
    dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2];
    dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2];
}

// 48 multiplies
void OptRotateBand2(float x[5], mat3 R, out float dst[5])
{
    // Sparse matrix multiply
    float sh0 =  x[3] + x[4] + x[4] - x[1];
    float sh1 =  x[0] + s_rc2*x[2] +  x[3] + x[4];
    float sh2 =  x[0];
    float sh3 = -x[3];
    float sh4 = -x[1];

    // Rotations.  R0 and R1 just use the raw matrix columns
    float r2x = R[0][0] + R[0][1];
    float r2y = R[1][0] + R[1][1];
    float r2z = R[2][0] + R[2][1];

    float r3x = R[0][0] + R[0][2];
    float r3y = R[1][0] + R[1][2];
    float r3z = R[2][0] + R[2][2];

    float r4x = R[0][1] + R[0][2];
    float r4y = R[1][1] + R[1][2];
    float r4z = R[2][1] + R[2][2];

    // dense matrix multiplication one column at a time

    // column 0
    float sh0_x = sh0 * R[0][0];
    float sh0_y = sh0 * R[1][0];
    float d0 = sh0_x * R[1][0];
    float d1 = sh0_y * R[2][0];
    float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3);
    float d3 = sh0_x * R[2][0];
    float d4 = sh0_x * R[0][0] - sh0_y * R[1][0];

    // column 1
    float sh1_x = sh1 * R[0][2];
    float sh1_y = sh1 * R[1][2];
    d0 += sh1_x * R[1][2];
    d1 += sh1_y * R[2][2];
    d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3);
    d3 += sh1_x * R[2][2];
    d4 += sh1_x * R[0][2] - sh1_y * R[1][2];

    // column 2
    float sh2_x = sh2 * r2x;
    float sh2_y = sh2 * r2y;
    d0 += sh2_x * r2y;
    d1 += sh2_y * r2z;
    d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2);
    d3 += sh2_x * r2z;
    d4 += sh2_x * r2x - sh2_y * r2y;

    // column 3
    float sh3_x = sh3 * r3x;
    float sh3_y = sh3 * r3y;
    d0 += sh3_x * r3y;
    d1 += sh3_y * r3z;
    d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2);
    d3 += sh3_x * r3z;
    d4 += sh3_x * r3x - sh3_y * r3y;

    // column 4
    float sh4_x = sh4 * r4x;
    float sh4_y = sh4 * r4y;
    d0 += sh4_x * r4y;
    d1 += sh4_y * r4z;
    d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2);
    d3 += sh4_x * r4z;
    d4 += sh4_x * r4x - sh4_y * r4y;

    // extra multipliers
    dst[0] = d0;
    dst[1] = -d1;
    dst[2] = d2 * s_scale_dst2;
    dst[3] = -d3;
    dst[4] = d4 * s_scale_dst4;
}

void main()
{
    // normalization
    vec3 pos = (NormMat * vec4(a_Position, 1.0)).xyz;

    mat3 R = mat3(ModelMat) * RotMat;
    VertexOut.ModelNormal = (R * a_Normal);
    VertexOut.Position = R * pos;
    VertexOut.Texcoord = a_TextureCoord;
    VertexOut.Tangent = (R * a_Tangent);
    VertexOut.Bitangent = (R * a_Bitangent);
    float PRT0, PRT1[3], PRT2[5];
    PRT0 = a_PRT1[0];
    PRT1[0] = a_PRT1[1];
    PRT1[1] = a_PRT1[2];
    PRT1[2] = a_PRT2[0];
    PRT2[0] = a_PRT2[1];
    PRT2[1] = a_PRT2[2];
    PRT2[2] = a_PRT3[0];
    PRT2[3] = a_PRT3[1];
    PRT2[4] = a_PRT3[2];

    OptRotateBand1(PRT1, R, PRT1);
    OptRotateBand2(PRT2, R, PRT2);

    VertexOut.PRT1 = vec3(PRT0, PRT1[0], PRT1[1]);
    VertexOut.PRT2 = vec3(PRT1[2], PRT2[0], PRT2[1]);
    VertexOut.PRT3 = vec3(PRT2[2], PRT2[3], PRT2[4]);

    gl_Position = PerspMat * ModelMat * vec4(RotMat * pos, 1.0);

    VertexOut.Depth = vec3(gl_Position.z / gl_Position.w);
}
    	
        PIFu/lib/renderer/gl/data/prt_uv.fs
    ADDED
    
    | @@ -0,0 +1,141 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
+#version 330
+
+uniform vec3 SHCoeffs[9];
+uniform uint analytic;
+
+uniform uint hasNormalMap;
+uniform uint hasAlbedoMap;
+
+uniform sampler2D AlbedoMap;
+uniform sampler2D NormalMap;
+
+in VertexData {
+    vec3 Position;
+    vec3 ModelNormal;
+    vec3 CameraNormal;
+    vec2 Texcoord;
+    vec3 Tangent;
+    vec3 Bitangent;
+    vec3 PRT1;
+    vec3 PRT2;
+    vec3 PRT3;
+} VertexIn;
+
+layout (location = 0) out vec4 FragColor;
+layout (location = 1) out vec4 FragPosition;
+layout (location = 2) out vec4 FragNormal;
+
+vec4 gammaCorrection(vec4 vec, float g)
+{
+    return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
+}
+
+vec3 gammaCorrection(vec3 vec, float g)
+{
+    return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
+}
+
+void evaluateH(vec3 n, out float H[9])
+{
+    float c1 = 0.429043, c2 = 0.511664,
+        c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;
+
+    H[0] = c4;
+    H[1] = 2.0 * c2 * n[1];
+    H[2] = 2.0 * c2 * n[2];
+    H[3] = 2.0 * c2 * n[0];
+    H[4] = 2.0 * c1 * n[0] * n[1];
+    H[5] = 2.0 * c1 * n[1] * n[2];
+    H[6] = c3 * n[2] * n[2] - c5;
+    H[7] = 2.0 * c1 * n[2] * n[0];
+    H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
+}
+
+vec3 evaluateLightingModel(vec3 normal)
+{
+    float H[9];
+    evaluateH(normal, H);
+    vec3 res = vec3(0.0);
+    for (int i = 0; i < 9; i++) {
+        res += H[i] * SHCoeffs[i];
+    }
+    return res;
+}
+
+// nC: coarse geometry normal, nH: fine normal from normal map
+vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
+{
+    float HC[9], HH[9];
+    evaluateH(nC, HC);
+    evaluateH(nH, HH);
+
+    vec3 res = vec3(0.0);
+    vec3 shadow = vec3(0.0);
+    vec3 unshadow = vec3(0.0);
+    for(int i = 0; i < 3; ++i){
+        for(int j = 0; j < 3; ++j){
+            int id = i*3+j;
+            res += HH[id] * SHCoeffs[id];
+            shadow += prt[i][j] * SHCoeffs[id];
+            unshadow += HC[id] * SHCoeffs[id];
+        }
+    }
+    vec3 ratio = clamp(shadow/unshadow, 0.0, 1.0);
+    res = ratio * res;
+
+    return res;
+}
+
+vec3 evaluateLightingModelPRT(mat3 prt)
+{
+    vec3 res = vec3(0.0);
+    for(int i = 0; i < 3; ++i){
+        for(int j = 0; j < 3; ++j){
+            res += prt[i][j] * SHCoeffs[i*3+j];
+        }
+    }
+
+    return res;
+}
+
+void main()
+{
+    vec2 uv = VertexIn.Texcoord;
+    vec3 nM = normalize(VertexIn.ModelNormal);
+    vec3 nC = normalize(VertexIn.CameraNormal);
+    vec3 nml = nC;
+    mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);
+
+    vec4 albedo, shading;
+    if(hasAlbedoMap == uint(0))
+        albedo = vec4(1.0);
+    else
+        albedo = texture(AlbedoMap, uv); //gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);
+
+    if(hasNormalMap == uint(0))
+    {
+        if(analytic == uint(0))
+            shading = vec4(evaluateLightingModelPRT(prt), 1.0f);
+        else
+            shading = vec4(evaluateLightingModel(nC), 1.0f);
+    }
+    else
+    {
+        vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0 - vec3(1.0));
+
+        mat3 TBN = mat3(normalize(VertexIn.Tangent), normalize(VertexIn.Bitangent), nC);
+        vec3 nH = normalize(TBN * n_tan);
+
+        if(analytic == uint(0))
+            shading = vec4(evaluateLightingModelHybrid(nC, nH, prt), 1.0f);
+        else
+            shading = vec4(evaluateLightingModel(nH), 1.0f);
+
+        nml = nH;
+    }
+
+    shading = gammaCorrection(shading, 2.2);
+    FragColor = clamp(albedo * shading, 0.0, 1.0);
+    FragPosition = vec4(VertexIn.Position, 1.0);
+    FragNormal = vec4(0.5*(nM + vec3(1.0)), 1.0);
+}
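This fragment shader evaluates SH lighting three ways: analytically from a normal (evaluateLightingModel, using the standard irradiance constants of Ramamoorthi and Hanrahan), from the rotated per-vertex PRT matrix (evaluateLightingModelPRT), or as a hybrid where the PRT/analytic ratio acts as a soft shadow term on normal-mapped shading. The unshadowed path reduces to a 9-term dot product per color channel; a NumPy sketch (hypothetical helper names, assuming the SHCoeffs uniform as a (9, 3) RGB array):

    import numpy as np

    def evaluate_h(n):
        # Same constants and coefficient ordering as evaluateH in prt_uv.fs.
        c1, c2, c3, c4, c5 = 0.429043, 0.511664, 0.743125, 0.886227, 0.247708
        x, y, z = n
        return np.array([
            c4,
            2.0 * c2 * y, 2.0 * c2 * z, 2.0 * c2 * x,
            2.0 * c1 * x * y, 2.0 * c1 * y * z,
            c3 * z * z - c5,
            2.0 * c1 * z * x,
            c1 * (x * x - y * y),
        ])

    def shade(normal, sh_coeffs):
        # sh_coeffs: (9, 3) SH lighting coefficients; returns (3,) linear RGB,
        # matching evaluateLightingModel for a unit-length normal.
        return evaluate_h(normal) @ sh_coeffs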
    	
        PIFu/lib/renderer/gl/data/prt_uv.vs
    ADDED
    
@@ -0,0 +1,168 @@
+#version 330
+
+layout (location = 0) in vec3 a_Position;
+layout (location = 1) in vec3 a_Normal;
+layout (location = 2) in vec2 a_TextureCoord;
+layout (location = 3) in vec3 a_Tangent;
+layout (location = 4) in vec3 a_Bitangent;
+layout (location = 5) in vec3 a_PRT1;
+layout (location = 6) in vec3 a_PRT2;
+layout (location = 7) in vec3 a_PRT3;
+
+out VertexData {
+    vec3 Position;
+    vec3 ModelNormal;
+    vec3 CameraNormal;
+    vec2 Texcoord;
+    vec3 Tangent;
+    vec3 Bitangent;
+    vec3 PRT1;
+    vec3 PRT2;
+    vec3 PRT3;
+} VertexOut;
+
+uniform mat3 RotMat;
+uniform mat4 NormMat;
+uniform mat4 ModelMat;
+uniform mat4 PerspMat;
+
+#define pi 3.1415926535897932384626433832795
+
+float s_c3 = 0.94617469575;  // (3*sqrt(5))/(4*sqrt(pi))
+float s_c4 = -0.31539156525; // (-sqrt(5))/(4*sqrt(pi))
+float s_c5 = 0.54627421529;  // (sqrt(15))/(4*sqrt(pi))
+
+float s_c_scale = 1.0/0.91529123286551084;
+float s_c_scale_inv = 0.91529123286551084;
+
+float s_rc2 = 1.5853309190550713*s_c_scale;
+float s_c4_div_c3 = s_c4/s_c3;
+float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0;
+
+float s_scale_dst2 = s_c3 * s_c_scale_inv;
+float s_scale_dst4 = s_c5 * s_c_scale_inv;
+
+void OptRotateBand0(float x[1], mat3 R, out float dst[1])
+{
+    dst[0] = x[0];
+}
+
+// 9 multiplies
+void OptRotateBand1(float x[3], mat3 R, out float dst[3])
+{
+    // derived from SlowRotateBand1
+    dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2];
+    dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2];
+    dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2];
+}
+
+// 48 multiplies
+void OptRotateBand2(float x[5], mat3 R, out float dst[5])
+{
+    // sparse matrix multiply
+    float sh0 =  x[3] + x[4] + x[4] - x[1];
+    float sh1 =  x[0] + s_rc2*x[2] + x[3] + x[4];
+    float sh2 =  x[0];
+    float sh3 = -x[3];
+    float sh4 = -x[1];
+
+    // Rotations. R0 and R1 just use the raw matrix columns
+    float r2x = R[0][0] + R[0][1];
+    float r2y = R[1][0] + R[1][1];
+    float r2z = R[2][0] + R[2][1];
+
+    float r3x = R[0][0] + R[0][2];
+    float r3y = R[1][0] + R[1][2];
+    float r3z = R[2][0] + R[2][2];
+
+    float r4x = R[0][1] + R[0][2];
+    float r4y = R[1][1] + R[1][2];
+    float r4z = R[2][1] + R[2][2];
+
+    // dense matrix multiplication, one column at a time
+
+    // column 0
+    float sh0_x = sh0 * R[0][0];
+    float sh0_y = sh0 * R[1][0];
+    float d0 = sh0_x * R[1][0];
+    float d1 = sh0_y * R[2][0];
+    float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3);
+    float d3 = sh0_x * R[2][0];
+    float d4 = sh0_x * R[0][0] - sh0_y * R[1][0];
+
+    // column 1
+    float sh1_x = sh1 * R[0][2];
+    float sh1_y = sh1 * R[1][2];
+    d0 += sh1_x * R[1][2];
+    d1 += sh1_y * R[2][2];
+    d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3);
+    d3 += sh1_x * R[2][2];
+    d4 += sh1_x * R[0][2] - sh1_y * R[1][2];
+
+    // column 2
+    float sh2_x = sh2 * r2x;
+    float sh2_y = sh2 * r2y;
+    d0 += sh2_x * r2y;
+    d1 += sh2_y * r2z;
+    d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2);
+    d3 += sh2_x * r2z;
+    d4 += sh2_x * r2x - sh2_y * r2y;
+
+    // column 3
+    float sh3_x = sh3 * r3x;
+    float sh3_y = sh3 * r3y;
+    d0 += sh3_x * r3y;
+    d1 += sh3_y * r3z;
+    d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2);
+    d3 += sh3_x * r3z;
+    d4 += sh3_x * r3x - sh3_y * r3y;
+
+    // column 4
+    float sh4_x = sh4 * r4x;
+    float sh4_y = sh4 * r4y;
+    d0 += sh4_x * r4y;
+    d1 += sh4_y * r4z;
+    d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2);
+    d3 += sh4_x * r4z;
+    d4 += sh4_x * r4x - sh4_y * r4y;
+
+    // extra multipliers
+    dst[0] = d0;
+    dst[1] = -d1;
+    dst[2] = d2 * s_scale_dst2;
+    dst[3] = -d3;
+    dst[4] = d4 * s_scale_dst4;
+}
+
+void main()
+{
+    // normalization
+    mat3 R = mat3(ModelMat) * RotMat;
+    VertexOut.ModelNormal = a_Normal;
+    VertexOut.CameraNormal = (R * a_Normal);
+    VertexOut.Position = a_Position;
+    VertexOut.Texcoord = a_TextureCoord;
+    VertexOut.Tangent = (R * a_Tangent);
+    VertexOut.Bitangent = (R * a_Bitangent);
+    float PRT0, PRT1[3], PRT2[5];
+    PRT0 = a_PRT1[0];
+    PRT1[0] = a_PRT1[1];
+    PRT1[1] = a_PRT1[2];
+    PRT1[2] = a_PRT2[0];
+    PRT2[0] = a_PRT2[1];
+    PRT2[1] = a_PRT2[2];
+    PRT2[2] = a_PRT3[0];
+    PRT2[3] = a_PRT3[1];
+    PRT2[4] = a_PRT3[2];
+
+    OptRotateBand1(PRT1, R, PRT1);
+    OptRotateBand2(PRT2, R, PRT2);
+
+    VertexOut.PRT1 = vec3(PRT0, PRT1[0], PRT1[1]);
+    VertexOut.PRT2 = vec3(PRT1[2], PRT2[0], PRT2[1]);
+    VertexOut.PRT3 = vec3(PRT2[2], PRT2[3], PRT2[4]);
+
+    // rasterize in UV space: remap texture coordinates [0,1]^2 to clip space [-1,1]^2
+    gl_Position = vec4(a_TextureCoord, 0.0, 1.0) - vec4(0.5, 0.5, 0, 0);
+    gl_Position[0] *= 2.0;
+    gl_Position[1] *= 2.0;
+}
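Unlike prt.vs, this vertex shader ignores the camera when positioning vertices: the final three lines of main() place every vertex at its texture coordinate, remapped from [0,1]^2 to the [-1,1]^2 clip square, so the mesh is rasterized directly into its UV atlas (e.g. for baking per-texel shading). The remap itself, as a small NumPy sketch (illustrative, not from the repository):

    import numpy as np

    def uv_to_clip(uv):
        """Map (N, 2) texture coordinates in [0, 1]^2 to (N, 4) clip-space
        positions on the z = 0 plane, matching gl_Position in prt_uv.vs."""
        uv = np.asarray(uv, dtype=np.float32)
        xy = (uv - 0.5) * 2.0                    # [0,1]^2 -> [-1,1]^2
        zw = np.tile([0.0, 1.0], (len(xy), 1))   # z = 0, w = 1
        return np.hstack([xy, zw]).astype(np.float32)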
    	
        PIFu/lib/renderer/gl/data/quad.fs
    ADDED
    
@@ -0,0 +1,11 @@
+#version 330 core
+out vec4 FragColor;
+
+in vec2 TexCoord;
+
+uniform sampler2D screenTexture;
+
+void main()
+{
+    FragColor = texture(screenTexture, TexCoord);
+}
    	
        PIFu/lib/renderer/gl/data/quad.vs
    ADDED
    
@@ -0,0 +1,11 @@
+#version 330 core
+layout (location = 0) in vec2 aPos;
+layout (location = 1) in vec2 aTexCoord;
+
+out vec2 TexCoord;
+
+void main()
+{
+    gl_Position = vec4(aPos.x, aPos.y, 0.0, 1.0);
+    TexCoord = aTexCoord;
+}
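quad.vs and quad.fs together implement the usual present pass: draw two triangles that cover clip space and sample an offscreen color buffer. The vertex buffer such a pass would consume might look like the following (a hypothetical layout matching the two attributes in quad.vs; the actual buffer setup lives elsewhere in the renderer):

    import numpy as np

    # Interleaved (x, y, u, v): two triangles covering the clip-space square.
    # location 0 (aPos) reads the first two floats per vertex,
    # location 1 (aTexCoord) the remaining two.
    quad_vertices = np.array([
        # x,     y,    u,   v
        -1.0,  1.0,  0.0, 1.0,
        -1.0, -1.0,  0.0, 0.0,
         1.0, -1.0,  1.0, 0.0,

        -1.0,  1.0,  0.0, 1.0,
         1.0, -1.0,  1.0, 0.0,
         1.0,  1.0,  1.0, 1.0,
    ], dtype=np.float32)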
    	
        PIFu/lib/renderer/gl/framework.py
    ADDED
    
@@ -0,0 +1,90 @@
+# Mario Rosasco, 2016
+# adapted from framework.cpp, Copyright (C) 2010-2012 by Jason L. McKesson
+# This file is licensed under the MIT License.
+#
+# NB: Unlike in the framework.cpp organization, the main loop is contained
+# in the tutorial files, not in this framework file. Additionally, a copy of
+# this module file must exist in the same directory as the tutorial files
+# to be imported properly.
+
+import os
+from OpenGL.GL import *
+
+# Creates and compiles a shader of the given type (a GL enum value) from a
+# file containing a GLSL program.
+def loadShader(shaderType, shaderFile):
+    # check if file exists, get full path name
+    strFilename = findFileOrThrow(shaderFile)
+    with open(strFilename, 'r') as f:
+        shaderData = f.read()
+
+    shader = glCreateShader(shaderType)
+    glShaderSource(shader, shaderData)  # note that this is a simpler function call than in C
+
+    # This shader compilation is more explicit than the one used in
+    # framework.cpp, which relies on a glutil wrapper function.
+    # This is made explicit here mainly to decrease dependence on pyOpenGL
+    # utilities and wrappers, whose docs caution that they may change in future versions.
+    glCompileShader(shader)
+
+    status = glGetShaderiv(shader, GL_COMPILE_STATUS)
+    if status == GL_FALSE:
+        # Note that getting the error log is much simpler in Python than in C/C++
+        # and does not require explicit handling of the string buffer
+        strInfoLog = glGetShaderInfoLog(shader)
+        strShaderType = ""
+        # compare by value ('=='), not identity: GL shader types are numeric enums
+        if shaderType == GL_VERTEX_SHADER:
+            strShaderType = "vertex"
+        elif shaderType == GL_GEOMETRY_SHADER:
+            strShaderType = "geometry"
+        elif shaderType == GL_FRAGMENT_SHADER:
+            strShaderType = "fragment"
+
+        print("Compilation failure for " + strShaderType + " shader:\n" + str(strInfoLog))
+
+    return shader
+
+
+# Accepts a list of compiled shaders, links them, and returns a handle to the program
+def createProgram(shaderList):
+    program = glCreateProgram()
+
+    for shader in shaderList:
+        glAttachShader(program, shader)
+
+    glLinkProgram(program)
+
+    status = glGetProgramiv(program, GL_LINK_STATUS)
+    if status == GL_FALSE:
+        # Note that getting the error log is much simpler in Python than in C/C++
+        # and does not require explicit handling of the string buffer
+        strInfoLog = glGetProgramInfoLog(program)
+        print("Linker failure: \n" + str(strInfoLog))
+
+    for shader in shaderList:
+        glDetachShader(program, shader)
+
+    return program
+
+
+# Helper function to locate and open the target file (passed in as a string).
+# Returns the full path to the file as a string.
+def findFileOrThrow(strBasename):
+    # Keep constant names in C-style convention, for readability
+    # when comparing to C(/C++) code.
+    if os.path.isfile(strBasename):
+        return strBasename
+
+    LOCAL_FILE_DIR = "data" + os.sep
+    GLOBAL_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) + os.sep + "data" + os.sep
+
+    strFilename = LOCAL_FILE_DIR + strBasename
+    if os.path.isfile(strFilename):
+        return strFilename
+
+    strFilename = GLOBAL_FILE_DIR + strBasename
+    if os.path.isfile(strFilename):
+        return strFilename
+
+    raise IOError('Could not find target file ' + strBasename)
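A minimal usage sketch for these helpers (hypothetical: it assumes a current GL context and that the quad shaders above are resolvable through findFileOrThrow's data/ lookup):

    from OpenGL.GL import GL_VERTEX_SHADER, GL_FRAGMENT_SHADER, glUseProgram

    shader_list = [
        loadShader(GL_VERTEX_SHADER, 'quad.vs'),
        loadShader(GL_FRAGMENT_SHADER, 'quad.fs'),
    ]
    program = createProgram(shader_list)  # links, then detaches the shader objects
    glUseProgram(program)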
    	
        PIFu/lib/renderer/gl/glcontext.py
    ADDED
    
@@ -0,0 +1,142 @@
+"""Headless GPU-accelerated OpenGL context creation on Google Colaboratory.
+
+Typical usage:
+
+    # Optional PyOpenGL configuration can be done here.
+    # import OpenGL
+    # OpenGL.ERROR_CHECKING = True
+
+    # 'glcontext' must be imported before any OpenGL.* API.
+    from lucid.misc.gl.glcontext import create_opengl_context
+
+    # Now it's safe to import OpenGL and EGL functions
+    import OpenGL.GL as gl
+
+    # create_opengl_context() creates a GL context that is attached to an
+    # offscreen surface of the specified size. Note that rendering to buffers
+    # of other sizes and formats is still possible with OpenGL Framebuffers.
+    #
+    # Users are expected to directly use the EGL API in case more advanced
+    # context management is required.
+    width, height = 640, 480
+    create_opengl_context((width, height))
+
+    # OpenGL context is available here.
+
+"""
+
+from __future__ import print_function
+
+# pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports
+
+try:
+  import OpenGL
+except:
+  print('This module depends on PyOpenGL.')
+  print('Please run "\033[1m!pip install -q pyopengl\033[0m" '
+        'prior to importing this module.')
+  raise
+
+import ctypes
+from ctypes import pointer, util
+import os
+
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+
+# OpenGL loading workaround.
+#
+# * PyOpenGL tries to load libGL, but we need libOpenGL, see [1,2].
+#   This could have been solved by a symlink libGL->libOpenGL, but:
+#
+# * Python 2.7 can't find libGL and libEGL due to a bug (see [3])
+#   in ctypes.util that was only fixed in Python 3.6.
+#
+# So, the only solution I've found is to monkeypatch ctypes.util.
+# [1] https://devblogs.nvidia.com/egl-eye-opengl-visualization-without-x-server/
+# [2] https://devblogs.nvidia.com/linking-opengl-server-side-rendering/
+# [3] https://bugs.python.org/issue9998
+_find_library_old = ctypes.util.find_library
+try:
+
+  def _find_library_new(name):
+    return {
+        'GL': 'libOpenGL.so',
+        'EGL': 'libEGL.so',
+    }.get(name, _find_library_old(name))
+  util.find_library = _find_library_new
+  import OpenGL.GL as gl
+  import OpenGL.EGL as egl
+  from OpenGL import error
+  from OpenGL.EGL.EXT.device_base import egl_get_devices
+  from OpenGL.raw.EGL.EXT.platform_device import EGL_PLATFORM_DEVICE_EXT
+except:
+  print('Unable to load OpenGL libraries. '
+        'Make sure you use a GPU-enabled backend.')
+  print('Press "Runtime->Change runtime type" and set '
+        '"Hardware accelerator" to GPU.')
+  raise
+finally:
+  util.find_library = _find_library_old
+
+def create_initialized_headless_egl_display():
+  """Creates an initialized EGL display directly on a device."""
+  for device in egl_get_devices():
+    display = egl.eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, device, None)
+
+    if display != egl.EGL_NO_DISPLAY and egl.eglGetError() == egl.EGL_SUCCESS:
+      # `eglInitialize` may or may not raise an exception on failure depending
+      # on how PyOpenGL is configured. We therefore catch a `GLError` and also
+      # manually check the output of `eglGetError()` here.
+      try:
+        initialized = egl.eglInitialize(display, None, None)
+      except error.GLError:
+        pass
+      else:
+        if initialized == egl.EGL_TRUE and egl.eglGetError() == egl.EGL_SUCCESS:
+          return display
+  return egl.EGL_NO_DISPLAY
+
+def create_opengl_context(surface_size=(640, 480)):
+  """Create an offscreen OpenGL context and make it current.
+
+  Users are expected to directly use the EGL API in case more advanced
+  context management is required.
+
+  Args:
+    surface_size: (width, height), size of the offscreen rendering surface.
+  """
+  egl_display = create_initialized_headless_egl_display()
+  if egl_display == egl.EGL_NO_DISPLAY:
+    raise ImportError('Cannot initialize a headless EGL display.')
+
+  major, minor = egl.EGLint(), egl.EGLint()
+  egl.eglInitialize(egl_display, pointer(major), pointer(minor))
+
+  config_attribs = [
+      egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, egl.EGL_BLUE_SIZE, 8,
+      egl.EGL_GREEN_SIZE, 8, egl.EGL_RED_SIZE, 8, egl.EGL_DEPTH_SIZE, 24,
+      egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, egl.EGL_NONE
+  ]
+  config_attribs = (egl.EGLint * len(config_attribs))(*config_attribs)
+
+  num_configs = egl.EGLint()
+  egl_cfg = egl.EGLConfig()
+  egl.eglChooseConfig(egl_display, config_attribs, pointer(egl_cfg), 1,
+                      pointer(num_configs))
+
+  width, height = surface_size
+  pbuffer_attribs = [
+      egl.EGL_WIDTH,
+      width,
+      egl.EGL_HEIGHT,
+      height,
+      egl.EGL_NONE,
+  ]
+  pbuffer_attribs = (egl.EGLint * len(pbuffer_attribs))(*pbuffer_attribs)
+  egl_surf = egl.eglCreatePbufferSurface(egl_display, egl_cfg, pbuffer_attribs)
+
+  egl.eglBindAPI(egl.EGL_OPENGL_API)
+
+  egl_context = egl.eglCreateContext(egl_display, egl_cfg, egl.EGL_NO_CONTEXT,
+                                     None)
+  egl.eglMakeCurrent(egl_display, egl_surf, egl_surf, egl_context)
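Once create_opengl_context() returns, the context is current on the calling thread. A quick sanity check (illustrative only; assumes the EGL path above succeeded):

    import OpenGL.GL as gl

    print(gl.glGetString(gl.GL_VERSION))    # e.g. b'4.6.0 NVIDIA ...'
    print(gl.glGetString(gl.GL_RENDERER))   # confirms the GPU-backed EGL device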
    	
        PIFu/lib/renderer/gl/init_gl.py
    ADDED
    
@@ -0,0 +1,24 @@
+_glut_window = None
+_context_inited = None
+
+def initialize_GL_context(width=512, height=512, egl=False):
+    '''
+    default context uses GLUT
+    '''
+    if not egl:
+        import OpenGL.GLUT as GLUT
+        display_mode = GLUT.GLUT_DOUBLE | GLUT.GLUT_RGB | GLUT.GLUT_DEPTH
+        global _glut_window
+        if _glut_window is None:
+            GLUT.glutInit()
+            GLUT.glutInitDisplayMode(display_mode)
+            GLUT.glutInitWindowSize(width, height)
+            GLUT.glutInitWindowPosition(0, 0)
+            _glut_window = GLUT.glutCreateWindow("My Render.")
+    else:
+        from .glcontext import create_opengl_context
+        global _context_inited
+        if _context_inited is None:
+            create_opengl_context((width, height))
+            _context_inited = True
+
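Callers choose the backend with the egl flag: the GLUT path needs a display server, while the EGL path works headless (Colab, remote servers). Repeated calls are no-ops thanks to the module-level guards. A usage sketch (the import path is assumed here, relative to the package root):

    from lib.renderer.gl.init_gl import initialize_GL_context

    # Headless (e.g. Colab with a GPU runtime): EGL-backed offscreen context.
    initialize_GL_context(width=512, height=512, egl=True)

    # Desktop with a display server: hidden GLUT window instead.
    # initialize_GL_context(width=512, height=512, egl=False)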
