Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	first
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- .gitattributes +2 -0
- PTI/.gitignore +4 -0
- PTI/LICENSE +21 -0
- PTI/README.md +230 -0
- PTI/__init__.py +0 -0
- PTI/configs/__init__.py +0 -0
- PTI/configs/evaluation_config.py +1 -0
- PTI/configs/global_config.py +12 -0
- PTI/configs/hyperparameters.py +28 -0
- PTI/configs/paths_config.py +31 -0
- PTI/criteria/__init__.py +0 -0
- PTI/criteria/l2_loss.py +8 -0
- PTI/criteria/localitly_regulizer.py +65 -0
- PTI/dnnlib/__init__.py +9 -0
- PTI/dnnlib/util.py +477 -0
- PTI/docs/joker_original.jpg +3 -0
- PTI/docs/joker_rotation.jpg +3 -0
- PTI/docs/model_rec.jpg +3 -0
- PTI/docs/stern_rotation.jpg +3 -0
- PTI/docs/teaser.jpg +3 -0
- PTI/docs/tyron_edit.jpg +3 -0
- PTI/docs/tyron_original.jpg +3 -0
- PTI/editings/ganspace.py +21 -0
- PTI/editings/ganspace_pca/ffhq_pca.pt +0 -0
- PTI/editings/interfacegan_directions/age.pt +0 -0
- PTI/editings/interfacegan_directions/rotation.pt +0 -0
- PTI/editings/interfacegan_directions/smile.pt +0 -0
- PTI/editings/latent_editor.py +23 -0
- PTI/evaluation/experiment_setting_creator.py +43 -0
- PTI/evaluation/qualitative_edit_comparison.py +156 -0
- PTI/models/StyleCLIP/__init__.py +0 -0
- PTI/models/StyleCLIP/criteria/__init__.py +0 -0
- PTI/models/StyleCLIP/criteria/clip_loss.py +17 -0
- PTI/models/StyleCLIP/criteria/id_loss.py +39 -0
- PTI/models/StyleCLIP/global_directions/GUI.py +103 -0
- PTI/models/StyleCLIP/global_directions/GenerateImg.py +50 -0
- PTI/models/StyleCLIP/global_directions/GetCode.py +232 -0
- PTI/models/StyleCLIP/global_directions/GetGUIData.py +67 -0
- PTI/models/StyleCLIP/global_directions/Inference.py +106 -0
- PTI/models/StyleCLIP/global_directions/MapTS.py +394 -0
- PTI/models/StyleCLIP/global_directions/PlayInteractively.py +197 -0
- PTI/models/StyleCLIP/global_directions/SingleChannel.py +109 -0
- PTI/models/StyleCLIP/global_directions/__init__.py +0 -0
- PTI/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy +3 -0
- PTI/models/StyleCLIP/global_directions/dnnlib/__init__.py +9 -0
- PTI/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py +20 -0
- PTI/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py +193 -0
- PTI/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py +181 -0
- PTI/models/StyleCLIP/global_directions/dnnlib/tflib/network.py +781 -0
- PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py +9 -0
    	
        .gitattributes
    CHANGED
    
    | @@ -1,3 +1,5 @@ | |
| 1 | 
             
            *.gif filter=lfs diff=lfs merge=lfs -text
         | 
| 2 | 
             
            *.png filter=lfs diff=lfs merge=lfs -text
         | 
| 3 | 
             
            *.jpg filter=lfs diff=lfs merge=lfs -text
         | 
|  | |
|  | 
|  | |
| 1 | 
             
            *.gif filter=lfs diff=lfs merge=lfs -text
         | 
| 2 | 
             
            *.png filter=lfs diff=lfs merge=lfs -text
         | 
| 3 | 
             
            *.jpg filter=lfs diff=lfs merge=lfs -text
         | 
| 4 | 
            +
            *.npy filter=lfs diff=lfs merge=lfs -text
         | 
| 5 | 
            +
            . filter=lfs diff=lfs merge=lfs -text
         | 
    	
        PTI/.gitignore
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            checkpoints
         | 
| 2 | 
            +
            __pycache__
         | 
| 3 | 
            +
            embeddings
         | 
| 4 | 
            +
            test
         | 
    	
        PTI/LICENSE
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            MIT License
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Copyright (c) 2021 Daniel Roich
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 6 | 
            +
            of this software and associated documentation files (the "Software"), to deal
         | 
| 7 | 
            +
            in the Software without restriction, including without limitation the rights
         | 
| 8 | 
            +
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 9 | 
            +
            copies of the Software, and to permit persons to whom the Software is
         | 
| 10 | 
            +
            furnished to do so, subject to the following conditions:
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            The above copyright notice and this permission notice shall be included in all
         | 
| 13 | 
            +
            copies or substantial portions of the Software.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 16 | 
            +
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 17 | 
            +
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 18 | 
            +
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 19 | 
            +
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 20 | 
            +
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 21 | 
            +
            SOFTWARE.
         | 
    	
        PTI/README.md
    ADDED
    
    | @@ -0,0 +1,230 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # PTI: Pivotal Tuning for Latent-based editing of Real Images     (ACM TOG 2022)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            <!-- > Recently, a surge of advanced facial editing techniques have been proposed
         | 
| 4 | 
            +
            that leverage the generative power of a pre-trained StyleGAN. To successfully
         | 
| 5 | 
            +
            edit an image this way, one must first project (or invert) the image into
         | 
| 6 | 
            +
            the pre-trained generator’s domain. As it turns out, however, StyleGAN’s
         | 
| 7 | 
            +
            latent space induces an inherent tradeoff between distortion and editability,
         | 
| 8 | 
            +
            i.e. between maintaining the original appearance and convincingly altering
         | 
| 9 | 
            +
            some of its attributes. Practically, this means it is still challenging to
         | 
| 10 | 
            +
            apply ID-preserving facial latent-space editing to faces which are out of the
         | 
| 11 | 
            +
            generator’s domain. In this paper, we present an approach to bridge this
         | 
| 12 | 
            +
            gap. Our technique slightly alters the generator, so that an out-of-domain
         | 
| 13 | 
            +
            image is faithfully mapped into an in-domain latent code. The key idea is
         | 
| 14 | 
            +
            pivotal tuning — a brief training process that preserves the editing quality
         | 
| 15 | 
            +
            of an in-domain latent region, while changing its portrayed identity and
         | 
| 16 | 
            +
            appearance. In Pivotal Tuning Inversion (PTI), an initial inverted latent code
         | 
| 17 | 
            +
            serves as a pivot, around which the generator is fined-tuned. At the same
         | 
| 18 | 
            +
            time, a regularization term keeps nearby identities intact, to locally contain
         | 
| 19 | 
            +
            the effect. This surgical training process ends up altering appearance features
         | 
| 20 | 
            +
            that represent mostly identity, without affecting editing capabilities.
         | 
| 21 | 
            +
            To supplement this, we further show that pivotal tuning can also adjust the
         | 
| 22 | 
            +
            generator to accommodate a multitude of faces, while introducing negligible
         | 
| 23 | 
            +
            distortion on the rest of the domain. We validate our technique through
         | 
| 24 | 
            +
            inversion and editing metrics, and show preferable scores to state-of-the-art
         | 
| 25 | 
            +
            methods. We further qualitatively demonstrate our technique by applying
         | 
| 26 | 
            +
            advanced edits (such as pose, age, or expression) to numerous images of
         | 
| 27 | 
            +
            well-known and recognizable identities. Finally, we demonstrate resilience
         | 
| 28 | 
            +
            to harder cases, including heavy make-up, elaborate hairstyles and/or headwear,
         | 
| 29 | 
            +
            which otherwise could not have been successfully inverted and edited
         | 
| 30 | 
            +
            by state-of-the-art methods. -->
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            <a href="https://arxiv.org/abs/2106.05744"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
         | 
| 33 | 
            +
            <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>  
         | 
| 34 | 
            +
            Inference Notebook: <a href="https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=20></a>  
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            <p align="center">
         | 
| 37 | 
            +
            <img src="docs/teaser.jpg"/>  
         | 
| 38 | 
            +
            <br>
         | 
| 39 | 
            +
            Pivotal Tuning Inversion (PTI) enables employing off-the-shelf latent based
         | 
| 40 | 
            +
            semantic editing techniques on real images using StyleGAN. 
         | 
| 41 | 
            +
            PTI excels in identity preserving edits, portrayed through recognizable figures —
         | 
| 42 | 
            +
            Serena Williams and Robert Downey Jr. (top), and in handling faces which
         | 
| 43 | 
            +
            are clearly out-of-domain, e.g., due to heavy makeup (bottom).
         | 
| 44 | 
            +
            </br>
         | 
| 45 | 
            +
            </p>
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            ## Description   
         | 
| 48 | 
            +
            Official Implementation of our PTI paper + code for evaluation metrics. PTI introduces an optimization mechanizem for solving the StyleGAN inversion task.
         | 
| 49 | 
            +
            Providing near-perfect reconstruction results while maintaining the high editing abilitis of the native StyleGAN latent space W. For more details, see <a href="https://arxiv.org/abs/2106.05744"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            ## Recent Updates
         | 
| 52 | 
            +
            **2021.07.01**: Fixed files download phase in the inference notebook. Which might caused the notebook not to run smoothly.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            **2021.06.29**: Added support for CPU. In order to run PTI on CPU please change `device` parameter under `configs/global_config.py` to "cpu" instead of "cuda".
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            **2021.06.25** : Adding mohawk edit using StyleCLIP+PTI in inference notebook.
         | 
| 57 | 
            +
            	      Updating documentation in inference notebook due to Google Drive rate limit reached.
         | 
| 58 | 
            +
            	      Currently, Google Drive does not allow to download the pretrined models using Colab automatically. Manual intervention might be needed.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            ## Getting Started
         | 
| 61 | 
            +
            ### Prerequisites
         | 
| 62 | 
            +
            - Linux or macOS
         | 
| 63 | 
            +
            - NVIDIA GPU + CUDA CuDNN (Not mandatory bur recommended)
         | 
| 64 | 
            +
            - Python 3
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            ### Installation
         | 
| 67 | 
            +
            - Dependencies:  
         | 
| 68 | 
            +
            	1. lpips
         | 
| 69 | 
            +
            	2. wandb
         | 
| 70 | 
            +
            	3. pytorch
         | 
| 71 | 
            +
            	4. torchvision
         | 
| 72 | 
            +
            	5. matplotlib
         | 
| 73 | 
            +
            	6. dlib
         | 
| 74 | 
            +
            - All dependencies can be installed using *pip install* and the package name
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            ## Pretrained Models
         | 
| 77 | 
            +
            Please download the pretrained models from the following links.
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            ### Auxiliary Models
         | 
| 80 | 
            +
            We provide various auxiliary models needed for PTI inversion task.  
         | 
| 81 | 
            +
            This includes the StyleGAN generator and pre-trained models used for loss computation.
         | 
| 82 | 
            +
            | Path | Description
         | 
| 83 | 
            +
            | :--- | :----------
         | 
| 84 | 
            +
            |[FFHQ StyleGAN](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl) | StyleGAN2-ada model trained on FFHQ with 1024x1024 output resolution.
         | 
| 85 | 
            +
            |[Dlib alignment](https://drive.google.com/file/d/1HKmjg6iXsWr4aFPuU0gBXPGR83wqMzq7/view?usp=sharing) | Dlib alignment used for images preproccessing.
         | 
| 86 | 
            +
            |[FFHQ e4e encoder](https://drive.google.com/file/d/1ALC5CLA89Ouw40TwvxcwebhzWXM5YSCm/view?usp=sharing) | Pretrained e4e encoder. Used for StyleCLIP editing.
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            Note: The StyleGAN model is used directly from the official [stylegan2-ada-pytorch implementation](https://github.com/NVlabs/stylegan2-ada-pytorch).
         | 
| 89 | 
            +
            For StyleCLIP pretrained mappers, please see [StyleCLIP's official routes](https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. 
         | 
| 93 | 
            +
            However, you may use your own paths by changing the necessary values in `configs/path_configs.py`. 
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            ## Inversion
         | 
| 97 | 
            +
            ### Preparing your Data
         | 
| 98 | 
            +
            In order to invert a real image and edit it you should first align and crop it to the correct size. To do so you should perform *One* of the following steps: 
         | 
| 99 | 
            +
            1. Run `notebooks/align_data.ipynb` and change the "images_path" variable to the raw images path
         | 
| 100 | 
            +
            2. Run `utils/align_data.py` and change the "images_path" variable to the raw images path
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
| 103 | 
            +
            ### Weights And Biases
         | 
| 104 | 
            +
            The project supports [Weights And Biases](https://wandb.ai/home) framework for experiment tracking. For the inversion task it enables visualization of the losses progression and the generator intermediate results during the initial inversion and the *Pivotal Tuning*(PT) procedure.
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            The log frequency can be adjusted using the parameters defined at `configs/global_config.py` under the "Logs" subsection.
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            There is no no need to have an account. However, in order to use the features provided by Weights and Biases you first have to register on their site.
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            ### Running PTI
         | 
| 112 | 
            +
            The main training script is `scripts/run_pti.py`. The script receives aligned and cropped images from paths configured in the "Input info" subscetion in
         | 
| 113 | 
            +
             `configs/paths_config.py`. 
         | 
| 114 | 
            +
            Results are saved to directories found at "Dirs for output files" under `configs/paths_config.py`. This includes inversion latent codes and tuned generators. 
         | 
| 115 | 
            +
            The hyperparametrs for the inversion task can be found at  `configs/hyperparameters.py`. They are intilized to the default values used in the paper. 
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            ## Editing
         | 
| 118 | 
            +
            By default, we assume that all auxiliary edit directions are downloaded and saved to the directory `editings`. 
         | 
| 119 | 
            +
            However, you may use your own paths by changing the necessary values in `configs/path_configs.py` under "Edit directions" subsection.
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            Example of editing code can be found at `scripts/latent_editor_wrapper.py`
         | 
| 122 | 
            +
             | 
| 123 | 
            +
            ## Inference Notebooks
         | 
| 124 | 
            +
            To help visualize the results of PTI we provide a Jupyter notebook found in `notebooks/inference_playground.ipynb`.   
         | 
| 125 | 
            +
            The notebook will download the pretrained models and run inference on a sample image found online or 
         | 
| 126 | 
            +
            on images of your choosing. It is recommended to run this in [Google Colab](https://colab.research.google.com/github/danielroich/PTI/blob/main/notebooks/inference_playground.ipynb).
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            The notebook demonstrates how to:
         | 
| 129 | 
            +
            - Invert an image using PTI
         | 
| 130 | 
            +
            - Visualise the inversion and use the PTI output
         | 
| 131 | 
            +
            - Edit the image after PTI using InterfaceGAN and StyleCLIP
         | 
| 132 | 
            +
            - Compare to other inversion methods
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            ## Evaluation
         | 
| 135 | 
            +
            Currently the repository supports qualitative evaluation for reconstruction of: PTI, SG2 (*W Space*), e4e, SG2Plus (*W+ Space*). 
         | 
| 136 | 
            +
            As well as editing using InterfaceGAN and GANSpace for the same inversion methods.
         | 
| 137 | 
            +
            To run the evaluation please see `evaluation/qualitative_edit_comparison.py`. Examples of the evaluation scripts are:
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            <p align="center">
         | 
| 140 | 
            +
            <img src="docs/model_rec.jpg"/>  
         | 
| 141 | 
            +
            <br>
         | 
| 142 | 
            +
            Reconsturction comparison between different methods. The images order is: Original image, W+ inversion, e4e inversion, W inversion, PTI inversion
         | 
| 143 | 
            +
            </br>  
         | 
| 144 | 
            +
            </p>
         | 
| 145 | 
            +
             | 
| 146 | 
            +
            <p align="center">
         | 
| 147 | 
            +
            <img src="docs/stern_rotation.jpg"/>  
         | 
| 148 | 
            +
            <br>
         | 
| 149 | 
            +
            InterfaceGAN pose edit comparison between different methods. The images order is: Original, W+, e4e, W, PTI
         | 
| 150 | 
            +
            </br>  
         | 
| 151 | 
            +
            </p>
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            <p align="center">
         | 
| 154 | 
            +
            <img src="docs/tyron_original.jpg" width="220" height="220"/>  
         | 
| 155 | 
            +
            <img src="docs/tyron_edit.jpg" width="220" height="220"/>
         | 
| 156 | 
            +
            <br>
         | 
| 157 | 
            +
            Image per edit or several edits without comparison
         | 
| 158 | 
            +
            </br>  
         | 
| 159 | 
            +
            </p>
         | 
| 160 | 
            +
             | 
| 161 | 
            +
            ###  Coming Soon - Quantitative evaluation and StyleCLIP qualitative evaluation
         | 
| 162 | 
            +
             | 
| 163 | 
            +
            ## Repository structure
         | 
| 164 | 
            +
            | Path | Description <img width=200>
         | 
| 165 | 
            +
            | :--- | :---
         | 
| 166 | 
            +
            | ├  configs | Folder containing configs defining Hyperparameters, paths and logging
         | 
| 167 | 
            +
            | ├  criteria | Folder containing various loss and regularization criterias for the optimization
         | 
| 168 | 
            +
            | ├  dnnlib | Folder containing internal utils for StyleGAN2-ada
         | 
| 169 | 
            +
            | ├  docs | Folder containing the latent space edit directions
         | 
| 170 | 
            +
            | ├  editings | Folder containing images displayed in the README
         | 
| 171 | 
            +
            | ├  environment | Folder containing Anaconda environment used in our experiments
         | 
| 172 | 
            +
            | ├  licenses | Folder containing licenses of the open source projects used in this repository
         | 
| 173 | 
            +
            | ├  models | Folder containing models used in different editing techniques and first phase inversion
         | 
| 174 | 
            +
            | ├  notebooks | Folder with jupyter notebooks to demonstrate the usage of PTI end-to-end
         | 
| 175 | 
            +
            | ├  scripts | Folder with running scripts for inversion, editing and metric computations
         | 
| 176 | 
            +
            | ├  torch_utils | Folder containing internal utils for StyleGAN2-ada
         | 
| 177 | 
            +
            | ├  training | Folder containing the core training logic of PTI
         | 
| 178 | 
            +
            | ├  utils | Folder with various utility functions
         | 
| 179 | 
            +
             | 
| 180 | 
            +
             | 
| 181 | 
            +
            ## Credits
         | 
| 182 | 
            +
            **StyleGAN2-ada model and implementation:**  
         | 
| 183 | 
            +
            https://github.com/NVlabs/stylegan2-ada-pytorch
         | 
| 184 | 
            +
            Copyright © 2021, NVIDIA Corporation.  
         | 
| 185 | 
            +
            Nvidia Source Code License https://nvlabs.github.io/stylegan2-ada-pytorch/license.html
         | 
| 186 | 
            +
             | 
| 187 | 
            +
            **LPIPS model and implementation:**  
         | 
| 188 | 
            +
            https://github.com/richzhang/PerceptualSimilarity  
         | 
| 189 | 
            +
            Copyright (c) 2020, Sou Uchida  
         | 
| 190 | 
            +
            License (BSD 2-Clause) https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE
         | 
| 191 | 
            +
             | 
| 192 | 
            +
            **e4e model and implementation:**   
         | 
| 193 | 
            +
            https://github.com/omertov/encoder4editing
         | 
| 194 | 
            +
            Copyright (c) 2021 omertov  
         | 
| 195 | 
            +
            License (MIT) https://github.com/omertov/encoder4editing/blob/main/LICENSE
         | 
| 196 | 
            +
             | 
| 197 | 
            +
            **StyleCLIP model and implementation:**   
         | 
| 198 | 
            +
            https://github.com/orpatashnik/StyleCLIP
         | 
| 199 | 
            +
            Copyright (c) 2021 orpatashnik  
         | 
| 200 | 
            +
            License (MIT) https://github.com/orpatashnik/StyleCLIP/blob/main/LICENSE
         | 
| 201 | 
            +
             | 
| 202 | 
            +
            **InterfaceGAN implementation:**   
         | 
| 203 | 
            +
            https://github.com/genforce/interfacegan
         | 
| 204 | 
            +
            Copyright (c) 2020 genforce  
         | 
| 205 | 
            +
            License (MIT) https://github.com/genforce/interfacegan/blob/master/LICENSE
         | 
| 206 | 
            +
             | 
| 207 | 
            +
            **GANSpace implementation:**   
         | 
| 208 | 
            +
            https://github.com/harskish/ganspace
         | 
| 209 | 
            +
            Copyright (c) 2020 harkish  
         | 
| 210 | 
            +
            License (Apache License 2.0) https://github.com/harskish/ganspace/blob/master/LICENSE
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            ## Acknowledgments
         | 
| 214 | 
            +
            This repository structure is based on [encoder4editing](https://github.com/omertov/encoder4editing) and [ReStyle](https://github.com/yuval-alaluf/restyle-encoder) repositories
         | 
| 215 | 
            +
             | 
| 216 | 
            +
            ## Contact
         | 
| 217 | 
            +
            For any inquiry please contact us at our email addresses: [email protected] or [email protected]
         | 
| 218 | 
            +
             | 
| 219 | 
            +
             | 
| 220 | 
            +
            ## Citation
         | 
| 221 | 
            +
            If you use this code for your research, please cite:
         | 
| 222 | 
            +
            ```
         | 
| 223 | 
            +
            @article{roich2021pivotal,
         | 
| 224 | 
            +
              title={Pivotal Tuning for Latent-based Editing of Real Images},
         | 
| 225 | 
            +
              author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
         | 
| 226 | 
            +
              publisher = {Association for Computing Machinery},
         | 
| 227 | 
            +
              journal={ACM Trans. Graph.},
         | 
| 228 | 
            +
              year={2021}
         | 
| 229 | 
            +
            }
         | 
| 230 | 
            +
            ```
         | 
    	
        PTI/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        PTI/configs/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        PTI/configs/evaluation_config.py
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            evaluated_methods = ['e4e', 'SG2', 'SG2Plus']
         | 
    	
        PTI/configs/global_config.py
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Device
         | 
| 2 | 
            +
            cuda_visible_devices = '0'
         | 
| 3 | 
            +
            device = 'cuda:0'
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # Logs
         | 
| 6 | 
            +
            training_step = 1
         | 
| 7 | 
            +
            image_rec_result_log_snapshot = 100
         | 
| 8 | 
            +
            pivotal_training_steps = 0
         | 
| 9 | 
            +
            model_snapshot_interval = 400
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Run name to be updated during PTI
         | 
| 12 | 
            +
            run_name = ''
         | 
    	
        PTI/configs/hyperparameters.py
    ADDED
    
    | @@ -0,0 +1,28 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ## Architechture
         | 
| 2 | 
            +
            lpips_type = "alex"
         | 
| 3 | 
            +
            first_inv_type = "w"
         | 
| 4 | 
            +
            optim_type = "adam"
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            ## Locality regularization
         | 
| 7 | 
            +
            latent_ball_num_of_samples = 1
         | 
| 8 | 
            +
            locality_regularization_interval = 1
         | 
| 9 | 
            +
            use_locality_regularization = False
         | 
| 10 | 
            +
            regulizer_l2_lambda = 0.1
         | 
| 11 | 
            +
            regulizer_lpips_lambda = 0.1
         | 
| 12 | 
            +
            regulizer_alpha = 30
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            ## Loss
         | 
| 15 | 
            +
            pt_l2_lambda = 1
         | 
| 16 | 
            +
            pt_lpips_lambda = 1
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            ## Steps
         | 
| 19 | 
            +
            LPIPS_value_threshold = 0.06
         | 
| 20 | 
            +
            max_pti_steps = 350
         | 
| 21 | 
            +
            first_inv_steps = 450
         | 
| 22 | 
            +
            max_images_to_invert = 30
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            ## Optimization
         | 
| 25 | 
            +
            pti_learning_rate = 3e-4
         | 
| 26 | 
            +
            first_inv_lr = 5e-3
         | 
| 27 | 
            +
            train_batch_size = 1
         | 
| 28 | 
            +
            use_last_w_pivots = False
         | 
    	
        PTI/configs/paths_config.py
    ADDED
    
    | @@ -0,0 +1,31 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ## Pretrained models paths
         | 
| 2 | 
            +
            e4e = 'PTI/pretrained_models/e4e_ffhq_encode.pt'
         | 
| 3 | 
            +
            stylegan2_ada_ffhq = '../PTI/pretrained_models/ffhq.pkl'
         | 
| 4 | 
            +
            style_clip_pretrained_mappers = ''
         | 
| 5 | 
            +
            ir_se50 = 'PTI/pretrained_models/model_ir_se50.pth'
         | 
| 6 | 
            +
            dlib = 'PTI/pretrained_models/align.dat'
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            ## Dirs for output files
         | 
| 9 | 
            +
            checkpoints_dir = 'PTI/checkpoints'
         | 
| 10 | 
            +
            embedding_base_dir = 'PTI/embeddings'
         | 
| 11 | 
            +
            styleclip_output_dir = 'PTI/StyleCLIP_results'
         | 
| 12 | 
            +
            experiments_output_dir = 'PTI/output'
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            ## Input info
         | 
| 15 | 
            +
            ### Input dir, where the images reside
         | 
| 16 | 
            +
            input_data_path = ''
         | 
| 17 | 
            +
            ### Inversion identifier, used to keeping track of the inversion results. Both the latent code and the generator
         | 
| 18 | 
            +
            input_data_id = 'barcelona'
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            ## Keywords
         | 
| 21 | 
            +
            pti_results_keyword = 'PTI'
         | 
| 22 | 
            +
            e4e_results_keyword = 'e4e'
         | 
| 23 | 
            +
            sg2_results_keyword = 'SG2'
         | 
| 24 | 
            +
            sg2_plus_results_keyword = 'SG2_plus'
         | 
| 25 | 
            +
            multi_id_model_type = 'multi_id'
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            ## Edit directions
         | 
| 28 | 
            +
            interfacegan_age = 'PTI/editings/interfacegan_directions/age.pt'
         | 
| 29 | 
            +
            interfacegan_smile = 'PTI/editings/interfacegan_directions/smile.pt'
         | 
| 30 | 
            +
            interfacegan_rotation = 'PTI/editings/interfacegan_directions/rotation.pt'
         | 
| 31 | 
            +
            ffhq_pca = 'PTI/editings/ganspace_pca/ffhq_pca.pt'
         | 
    	
        PTI/criteria/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        PTI/criteria/l2_loss.py
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            l2_criterion = torch.nn.MSELoss(reduction='mean')
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def l2_loss(real_images, generated_images):
         | 
| 7 | 
            +
                loss = l2_criterion(real_images, generated_images)
         | 
| 8 | 
            +
                return loss
         | 
    	
        PTI/criteria/localitly_regulizer.py
    ADDED
    
    | @@ -0,0 +1,65 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            from PTI.criteria import l2_loss
         | 
| 4 | 
            +
            from PTI.configs import hyperparameters
         | 
| 5 | 
            +
            from PTI.configs import global_config
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class Space_Regulizer:
         | 
| 9 | 
            +
                def __init__(self, original_G, lpips_net):
         | 
| 10 | 
            +
                    self.original_G = original_G
         | 
| 11 | 
            +
                    self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha
         | 
| 12 | 
            +
                    self.lpips_loss = lpips_net
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def get_morphed_w_code(self, new_w_code, fixed_w):
         | 
| 15 | 
            +
                    interpolation_direction = new_w_code - fixed_w
         | 
| 16 | 
            +
                    interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
         | 
| 17 | 
            +
                    direction_to_move = hyperparameters.regulizer_alpha * \
         | 
| 18 | 
            +
                        interpolation_direction / interpolation_direction_norm
         | 
| 19 | 
            +
                    result_w = fixed_w + direction_to_move
         | 
| 20 | 
            +
                    self.morphing_regulizer_alpha * fixed_w + \
         | 
| 21 | 
            +
                        (1 - self.morphing_regulizer_alpha) * new_w_code
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    return result_w
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def get_image_from_ws(self, w_codes, G):
         | 
| 26 | 
            +
                    return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes])
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False):
         | 
| 29 | 
            +
                    loss = 0.0
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    z_samples = np.random.randn(
         | 
| 32 | 
            +
                        num_of_sampled_latents, self.original_G.z_dim)
         | 
| 33 | 
            +
                    w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None,
         | 
| 34 | 
            +
                                                        truncation_psi=0.5)
         | 
| 35 | 
            +
                    territory_indicator_ws = [self.get_morphed_w_code(
         | 
| 36 | 
            +
                        w_code.unsqueeze(0), w_batch) for w_code in w_samples]
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    for w_code in territory_indicator_ws:
         | 
| 39 | 
            +
                        new_img = new_G.synthesis(
         | 
| 40 | 
            +
                            w_code, noise_mode='none', force_fp32=True)
         | 
| 41 | 
            +
                        with torch.no_grad():
         | 
| 42 | 
            +
                            old_img = self.original_G.synthesis(
         | 
| 43 | 
            +
                                w_code, noise_mode='none', force_fp32=True)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        if hyperparameters.regulizer_l2_lambda > 0:
         | 
| 46 | 
            +
                            l2_loss_val = l2_loss.l2_loss(old_img, new_img)
         | 
| 47 | 
            +
                            if use_wandb:
         | 
| 48 | 
            +
                                wandb.log({f'space_regulizer_l2_loss_val': l2_loss_val.detach().cpu()},
         | 
| 49 | 
            +
                                          step=global_config.training_step)
         | 
| 50 | 
            +
                            loss += l2_loss_val * hyperparameters.regulizer_l2_lambda
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                        if hyperparameters.regulizer_lpips_lambda > 0:
         | 
| 53 | 
            +
                            loss_lpips = self.lpips_loss(old_img, new_img)
         | 
| 54 | 
            +
                            loss_lpips = torch.mean(torch.squeeze(loss_lpips))
         | 
| 55 | 
            +
                            if use_wandb:
         | 
| 56 | 
            +
                                wandb.log({f'space_regulizer_lpips_loss_val': loss_lpips.detach().cpu()},
         | 
| 57 | 
            +
                                          step=global_config.training_step)
         | 
| 58 | 
            +
                            loss += loss_lpips * hyperparameters.regulizer_lpips_lambda
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    return loss / len(territory_indicator_ws)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def space_regulizer_loss(self, new_G, w_batch, use_wandb):
         | 
| 63 | 
            +
                    ret_val = self.ball_holder_loss_lazy(
         | 
| 64 | 
            +
                        new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb)
         | 
| 65 | 
            +
                    return ret_val
         | 
    	
        PTI/dnnlib/__init__.py
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .util import EasyDict, make_cache_dir_path
         | 
    	
        PTI/dnnlib/util.py
    ADDED
    
    | @@ -0,0 +1,477 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """Miscellaneous utility classes and functions."""
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import ctypes
         | 
| 12 | 
            +
            import fnmatch
         | 
| 13 | 
            +
            import importlib
         | 
| 14 | 
            +
            import inspect
         | 
| 15 | 
            +
            import numpy as np
         | 
| 16 | 
            +
            import os
         | 
| 17 | 
            +
            import shutil
         | 
| 18 | 
            +
            import sys
         | 
| 19 | 
            +
            import types
         | 
| 20 | 
            +
            import io
         | 
| 21 | 
            +
            import pickle
         | 
| 22 | 
            +
            import re
         | 
| 23 | 
            +
            import requests
         | 
| 24 | 
            +
            import html
         | 
| 25 | 
            +
            import hashlib
         | 
| 26 | 
            +
            import glob
         | 
| 27 | 
            +
            import tempfile
         | 
| 28 | 
            +
            import urllib
         | 
| 29 | 
            +
            import urllib.request
         | 
| 30 | 
            +
            import uuid
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            from distutils.util import strtobool
         | 
| 33 | 
            +
            from typing import Any, List, Tuple, Union
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            # Util classes
         | 
| 37 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            class EasyDict(dict):
         | 
| 41 | 
            +
                """Convenience class that behaves like a dict but allows access with the attribute syntax."""
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def __getattr__(self, name: str) -> Any:
         | 
| 44 | 
            +
                    try:
         | 
| 45 | 
            +
                        return self[name]
         | 
| 46 | 
            +
                    except KeyError:
         | 
| 47 | 
            +
                        raise AttributeError(name)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def __setattr__(self, name: str, value: Any) -> None:
         | 
| 50 | 
            +
                    self[name] = value
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def __delattr__(self, name: str) -> None:
         | 
| 53 | 
            +
                    del self[name]
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            class Logger(object):
         | 
| 57 | 
            +
                """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
         | 
| 60 | 
            +
                    self.file = None
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    if file_name is not None:
         | 
| 63 | 
            +
                        self.file = open(file_name, file_mode)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    self.should_flush = should_flush
         | 
| 66 | 
            +
                    self.stdout = sys.stdout
         | 
| 67 | 
            +
                    self.stderr = sys.stderr
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    sys.stdout = self
         | 
| 70 | 
            +
                    sys.stderr = self
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def __enter__(self) -> "Logger":
         | 
| 73 | 
            +
                    return self
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         | 
| 76 | 
            +
                    self.close()
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def write(self, text: Union[str, bytes]) -> None:
         | 
| 79 | 
            +
                    """Write text to stdout (and a file) and optionally flush."""
         | 
| 80 | 
            +
                    if isinstance(text, bytes):
         | 
| 81 | 
            +
                        text = text.decode()
         | 
| 82 | 
            +
                    if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
         | 
| 83 | 
            +
                        return
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    if self.file is not None:
         | 
| 86 | 
            +
                        self.file.write(text)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    self.stdout.write(text)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    if self.should_flush:
         | 
| 91 | 
            +
                        self.flush()
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def flush(self) -> None:
         | 
| 94 | 
            +
                    """Flush written text to both stdout and a file, if open."""
         | 
| 95 | 
            +
                    if self.file is not None:
         | 
| 96 | 
            +
                        self.file.flush()
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    self.stdout.flush()
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def close(self) -> None:
         | 
| 101 | 
            +
                    """Flush, close possible files, and remove stdout/stderr mirroring."""
         | 
| 102 | 
            +
                    self.flush()
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    # if using multiple loggers, prevent closing in wrong order
         | 
| 105 | 
            +
                    if sys.stdout is self:
         | 
| 106 | 
            +
                        sys.stdout = self.stdout
         | 
| 107 | 
            +
                    if sys.stderr is self:
         | 
| 108 | 
            +
                        sys.stderr = self.stderr
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    if self.file is not None:
         | 
| 111 | 
            +
                        self.file.close()
         | 
| 112 | 
            +
                        self.file = None
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            # Cache directories
         | 
| 116 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            _dnnlib_cache_dir = None
         | 
| 119 | 
            +
             | 
| 120 | 
            +
            def set_cache_dir(path: str) -> None:
         | 
| 121 | 
            +
                global _dnnlib_cache_dir
         | 
| 122 | 
            +
                _dnnlib_cache_dir = path
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            def make_cache_dir_path(*paths: str) -> str:
         | 
| 125 | 
            +
                if _dnnlib_cache_dir is not None:
         | 
| 126 | 
            +
                    return os.path.join(_dnnlib_cache_dir, *paths)
         | 
| 127 | 
            +
                if 'DNNLIB_CACHE_DIR' in os.environ:
         | 
| 128 | 
            +
                    return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
         | 
| 129 | 
            +
                if 'HOME' in os.environ:
         | 
| 130 | 
            +
                    return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
         | 
| 131 | 
            +
                if 'USERPROFILE' in os.environ:
         | 
| 132 | 
            +
                    return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
         | 
| 133 | 
            +
                return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            # Small util functions
         | 
| 136 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            def format_time(seconds: Union[int, float]) -> str:
         | 
| 140 | 
            +
                """Convert the seconds to human readable string with days, hours, minutes and seconds."""
         | 
| 141 | 
            +
                s = int(np.rint(seconds))
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                if s < 60:
         | 
| 144 | 
            +
                    return "{0}s".format(s)
         | 
| 145 | 
            +
                elif s < 60 * 60:
         | 
| 146 | 
            +
                    return "{0}m {1:02}s".format(s // 60, s % 60)
         | 
| 147 | 
            +
                elif s < 24 * 60 * 60:
         | 
| 148 | 
            +
                    return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
         | 
| 149 | 
            +
                else:
         | 
| 150 | 
            +
                    return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            def ask_yes_no(question: str) -> bool:
         | 
| 154 | 
            +
                """Ask the user the question until the user inputs a valid answer."""
         | 
| 155 | 
            +
                while True:
         | 
| 156 | 
            +
                    try:
         | 
| 157 | 
            +
                        print("{0} [y/n]".format(question))
         | 
| 158 | 
            +
                        return strtobool(input().lower())
         | 
| 159 | 
            +
                    except ValueError:
         | 
| 160 | 
            +
                        pass
         | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            def tuple_product(t: Tuple) -> Any:
         | 
| 164 | 
            +
                """Calculate the product of the tuple elements."""
         | 
| 165 | 
            +
                result = 1
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                for v in t:
         | 
| 168 | 
            +
                    result *= v
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                return result
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
            _str_to_ctype = {
         | 
| 174 | 
            +
                "uint8": ctypes.c_ubyte,
         | 
| 175 | 
            +
                "uint16": ctypes.c_uint16,
         | 
| 176 | 
            +
                "uint32": ctypes.c_uint32,
         | 
| 177 | 
            +
                "uint64": ctypes.c_uint64,
         | 
| 178 | 
            +
                "int8": ctypes.c_byte,
         | 
| 179 | 
            +
                "int16": ctypes.c_int16,
         | 
| 180 | 
            +
                "int32": ctypes.c_int32,
         | 
| 181 | 
            +
                "int64": ctypes.c_int64,
         | 
| 182 | 
            +
                "float32": ctypes.c_float,
         | 
| 183 | 
            +
                "float64": ctypes.c_double
         | 
| 184 | 
            +
            }
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
            def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
         | 
| 188 | 
            +
                """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
         | 
| 189 | 
            +
                type_str = None
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                if isinstance(type_obj, str):
         | 
| 192 | 
            +
                    type_str = type_obj
         | 
| 193 | 
            +
                elif hasattr(type_obj, "__name__"):
         | 
| 194 | 
            +
                    type_str = type_obj.__name__
         | 
| 195 | 
            +
                elif hasattr(type_obj, "name"):
         | 
| 196 | 
            +
                    type_str = type_obj.name
         | 
| 197 | 
            +
                else:
         | 
| 198 | 
            +
                    raise RuntimeError("Cannot infer type name from input")
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                assert type_str in _str_to_ctype.keys()
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                my_dtype = np.dtype(type_str)
         | 
| 203 | 
            +
                my_ctype = _str_to_ctype[type_str]
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                return my_dtype, my_ctype
         | 
| 208 | 
            +
             | 
| 209 | 
            +
             | 
| 210 | 
            +
            def is_pickleable(obj: Any) -> bool:
         | 
| 211 | 
            +
                try:
         | 
| 212 | 
            +
                    with io.BytesIO() as stream:
         | 
| 213 | 
            +
                        pickle.dump(obj, stream)
         | 
| 214 | 
            +
                    return True
         | 
| 215 | 
            +
                except:
         | 
| 216 | 
            +
                    return False
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
            # Functionality to import modules/objects by name, and call functions by name
         | 
| 220 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 221 | 
            +
             | 
| 222 | 
            +
            def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
         | 
| 223 | 
            +
                """Searches for the underlying module behind the name to some python object.
         | 
| 224 | 
            +
                Returns the module and the object name (original name with module part removed)."""
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                # allow convenience shorthands, substitute them by full names
         | 
| 227 | 
            +
                obj_name = re.sub("^np.", "numpy.", obj_name)
         | 
| 228 | 
            +
                obj_name = re.sub("^tf.", "tensorflow.", obj_name)
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                # list alternatives for (module_name, local_obj_name)
         | 
| 231 | 
            +
                parts = obj_name.split(".")
         | 
| 232 | 
            +
                name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                # try each alternative in turn
         | 
| 235 | 
            +
                for module_name, local_obj_name in name_pairs:
         | 
| 236 | 
            +
                    try:
         | 
| 237 | 
            +
                        module = importlib.import_module(module_name) # may raise ImportError
         | 
| 238 | 
            +
                        get_obj_from_module(module, local_obj_name) # may raise AttributeError
         | 
| 239 | 
            +
                        return module, local_obj_name
         | 
| 240 | 
            +
                    except:
         | 
| 241 | 
            +
                        pass
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                # maybe some of the modules themselves contain errors?
         | 
| 244 | 
            +
                for module_name, _local_obj_name in name_pairs:
         | 
| 245 | 
            +
                    try:
         | 
| 246 | 
            +
                        importlib.import_module(module_name) # may raise ImportError
         | 
| 247 | 
            +
                    except ImportError:
         | 
| 248 | 
            +
                        if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
         | 
| 249 | 
            +
                            raise
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                # maybe the requested attribute is missing?
         | 
| 252 | 
            +
                for module_name, local_obj_name in name_pairs:
         | 
| 253 | 
            +
                    try:
         | 
| 254 | 
            +
                        module = importlib.import_module(module_name) # may raise ImportError
         | 
| 255 | 
            +
                        get_obj_from_module(module, local_obj_name) # may raise AttributeError
         | 
| 256 | 
            +
                    except ImportError:
         | 
| 257 | 
            +
                        pass
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                # we are out of luck, but we have no idea why
         | 
| 260 | 
            +
                raise ImportError(obj_name)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
             | 
| 263 | 
            +
            def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
         | 
| 264 | 
            +
                """Traverses the object name and returns the last (rightmost) python object."""
         | 
| 265 | 
            +
                if obj_name == '':
         | 
| 266 | 
            +
                    return module
         | 
| 267 | 
            +
                obj = module
         | 
| 268 | 
            +
                for part in obj_name.split("."):
         | 
| 269 | 
            +
                    obj = getattr(obj, part)
         | 
| 270 | 
            +
                return obj
         | 
| 271 | 
            +
             | 
| 272 | 
            +
             | 
| 273 | 
            +
            def get_obj_by_name(name: str) -> Any:
         | 
| 274 | 
            +
                """Finds the python object with the given name."""
         | 
| 275 | 
            +
                module, obj_name = get_module_from_obj_name(name)
         | 
| 276 | 
            +
                return get_obj_from_module(module, obj_name)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
             | 
| 279 | 
            +
            def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
         | 
| 280 | 
            +
                """Finds the python object with the given name and calls it as a function."""
         | 
| 281 | 
            +
                assert func_name is not None
         | 
| 282 | 
            +
                func_obj = get_obj_by_name(func_name)
         | 
| 283 | 
            +
                assert callable(func_obj)
         | 
| 284 | 
            +
                return func_obj(*args, **kwargs)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
             | 
| 287 | 
            +
            def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
         | 
| 288 | 
            +
                """Finds the python class with the given name and constructs it with the given arguments."""
         | 
| 289 | 
            +
                return call_func_by_name(*args, func_name=class_name, **kwargs)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
             | 
| 292 | 
            +
            def get_module_dir_by_obj_name(obj_name: str) -> str:
         | 
| 293 | 
            +
                """Get the directory path of the module containing the given object name."""
         | 
| 294 | 
            +
                module, _ = get_module_from_obj_name(obj_name)
         | 
| 295 | 
            +
                return os.path.dirname(inspect.getfile(module))
         | 
| 296 | 
            +
             | 
| 297 | 
            +
             | 
| 298 | 
            +
            def is_top_level_function(obj: Any) -> bool:
         | 
| 299 | 
            +
                """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
         | 
| 300 | 
            +
                return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
            def get_top_level_function_name(obj: Any) -> str:
         | 
| 304 | 
            +
                """Return the fully-qualified name of a top-level function."""
         | 
| 305 | 
            +
                assert is_top_level_function(obj)
         | 
| 306 | 
            +
                module = obj.__module__
         | 
| 307 | 
            +
                if module == '__main__':
         | 
| 308 | 
            +
                    module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
         | 
| 309 | 
            +
                return module + "." + obj.__name__
         | 
| 310 | 
            +
             | 
| 311 | 
            +
             | 
| 312 | 
            +
            # File system helpers
         | 
| 313 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 314 | 
            +
             | 
| 315 | 
            +
            def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
         | 
| 316 | 
            +
                """List all files recursively in a given directory while ignoring given file and directory names.
         | 
| 317 | 
            +
                Returns list of tuples containing both absolute and relative paths."""
         | 
| 318 | 
            +
                assert os.path.isdir(dir_path)
         | 
| 319 | 
            +
                base_name = os.path.basename(os.path.normpath(dir_path))
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                if ignores is None:
         | 
| 322 | 
            +
                    ignores = []
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                result = []
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                for root, dirs, files in os.walk(dir_path, topdown=True):
         | 
| 327 | 
            +
                    for ignore_ in ignores:
         | 
| 328 | 
            +
                        dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                        # dirs need to be edited in-place
         | 
| 331 | 
            +
                        for d in dirs_to_remove:
         | 
| 332 | 
            +
                            dirs.remove(d)
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                        files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    absolute_paths = [os.path.join(root, f) for f in files]
         | 
| 337 | 
            +
                    relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    if add_base_to_relative:
         | 
| 340 | 
            +
                        relative_paths = [os.path.join(base_name, p) for p in relative_paths]
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    assert len(absolute_paths) == len(relative_paths)
         | 
| 343 | 
            +
                    result += zip(absolute_paths, relative_paths)
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                return result
         | 
| 346 | 
            +
             | 
| 347 | 
            +
             | 
| 348 | 
            +
            def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
         | 
| 349 | 
            +
                """Takes in a list of tuples of (src, dst) paths and copies files.
         | 
| 350 | 
            +
                Will create all necessary directories."""
         | 
| 351 | 
            +
                for file in files:
         | 
| 352 | 
            +
                    target_dir_name = os.path.dirname(file[1])
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    # will create all intermediate-level directories
         | 
| 355 | 
            +
                    if not os.path.exists(target_dir_name):
         | 
| 356 | 
            +
                        os.makedirs(target_dir_name)
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    shutil.copyfile(file[0], file[1])
         | 
| 359 | 
            +
             | 
| 360 | 
            +
             | 
| 361 | 
            +
            # URL helpers
         | 
| 362 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 363 | 
            +
             | 
| 364 | 
            +
            def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
         | 
| 365 | 
            +
                """Determine whether the given object is a valid URL string."""
         | 
| 366 | 
            +
                if not isinstance(obj, str) or not "://" in obj:
         | 
| 367 | 
            +
                    return False
         | 
| 368 | 
            +
                if allow_file_urls and obj.startswith('file://'):
         | 
| 369 | 
            +
                    return True
         | 
| 370 | 
            +
                try:
         | 
| 371 | 
            +
                    res = requests.compat.urlparse(obj)
         | 
| 372 | 
            +
                    if not res.scheme or not res.netloc or not "." in res.netloc:
         | 
| 373 | 
            +
                        return False
         | 
| 374 | 
            +
                    res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
         | 
| 375 | 
            +
                    if not res.scheme or not res.netloc or not "." in res.netloc:
         | 
| 376 | 
            +
                        return False
         | 
| 377 | 
            +
                except:
         | 
| 378 | 
            +
                    return False
         | 
| 379 | 
            +
                return True
         | 
| 380 | 
            +
             | 
| 381 | 
            +
             | 
| 382 | 
            +
            def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
         | 
| 383 | 
            +
                """Download the given URL and return a binary-mode file object to access the data."""
         | 
| 384 | 
            +
                assert num_attempts >= 1
         | 
| 385 | 
            +
                assert not (return_filename and (not cache))
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                # Doesn't look like an URL scheme so interpret it as a local filename.
         | 
| 388 | 
            +
                if not re.match('^[a-z]+://', url):
         | 
| 389 | 
            +
                    return url if return_filename else open(url, "rb")
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                # Handle file URLs.  This code handles unusual file:// patterns that
         | 
| 392 | 
            +
                # arise on Windows:
         | 
| 393 | 
            +
                #
         | 
| 394 | 
            +
                # file:///c:/foo.txt
         | 
| 395 | 
            +
                #
         | 
| 396 | 
            +
                # which would translate to a local '/c:/foo.txt' filename that's
         | 
| 397 | 
            +
                # invalid.  Drop the forward slash for such pathnames.
         | 
| 398 | 
            +
                #
         | 
| 399 | 
            +
                # If you touch this code path, you should test it on both Linux and
         | 
| 400 | 
            +
                # Windows.
         | 
| 401 | 
            +
                #
         | 
| 402 | 
            +
                # Some internet resources suggest using urllib.request.url2pathname() but
         | 
| 403 | 
            +
                # but that converts forward slashes to backslashes and this causes
         | 
| 404 | 
            +
                # its own set of problems.
         | 
| 405 | 
            +
                if url.startswith('file://'):
         | 
| 406 | 
            +
                    filename = urllib.parse.urlparse(url).path
         | 
| 407 | 
            +
                    if re.match(r'^/[a-zA-Z]:', filename):
         | 
| 408 | 
            +
                        filename = filename[1:]
         | 
| 409 | 
            +
                    return filename if return_filename else open(filename, "rb")
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                assert is_url(url)
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                # Lookup from cache.
         | 
| 414 | 
            +
                if cache_dir is None:
         | 
| 415 | 
            +
                    cache_dir = make_cache_dir_path('downloads')
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
         | 
| 418 | 
            +
                if cache:
         | 
| 419 | 
            +
                    cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
         | 
| 420 | 
            +
                    if len(cache_files) == 1:
         | 
| 421 | 
            +
                        filename = cache_files[0]
         | 
| 422 | 
            +
                        return filename if return_filename else open(filename, "rb")
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                # Download.
         | 
| 425 | 
            +
                url_name = None
         | 
| 426 | 
            +
                url_data = None
         | 
| 427 | 
            +
                with requests.Session() as session:
         | 
| 428 | 
            +
                    if verbose:
         | 
| 429 | 
            +
                        print("Downloading %s ..." % url, end="", flush=True)
         | 
| 430 | 
            +
                    for attempts_left in reversed(range(num_attempts)):
         | 
| 431 | 
            +
                        try:
         | 
| 432 | 
            +
                            with session.get(url) as res:
         | 
| 433 | 
            +
                                res.raise_for_status()
         | 
| 434 | 
            +
                                if len(res.content) == 0:
         | 
| 435 | 
            +
                                    raise IOError("No data received")
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                                if len(res.content) < 8192:
         | 
| 438 | 
            +
                                    content_str = res.content.decode("utf-8")
         | 
| 439 | 
            +
                                    if "download_warning" in res.headers.get("Set-Cookie", ""):
         | 
| 440 | 
            +
                                        links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
         | 
| 441 | 
            +
                                        if len(links) == 1:
         | 
| 442 | 
            +
                                            url = requests.compat.urljoin(url, links[0])
         | 
| 443 | 
            +
                                            raise IOError("Google Drive virus checker nag")
         | 
| 444 | 
            +
                                    if "Google Drive - Quota exceeded" in content_str:
         | 
| 445 | 
            +
                                        raise IOError("Google Drive download quota exceeded -- please try again later")
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                                match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
         | 
| 448 | 
            +
                                url_name = match[1] if match else url
         | 
| 449 | 
            +
                                url_data = res.content
         | 
| 450 | 
            +
                                if verbose:
         | 
| 451 | 
            +
                                    print(" done")
         | 
| 452 | 
            +
                                break
         | 
| 453 | 
            +
                        except KeyboardInterrupt:
         | 
| 454 | 
            +
                            raise
         | 
| 455 | 
            +
                        except:
         | 
| 456 | 
            +
                            if not attempts_left:
         | 
| 457 | 
            +
                                if verbose:
         | 
| 458 | 
            +
                                    print(" failed")
         | 
| 459 | 
            +
                                raise
         | 
| 460 | 
            +
                            if verbose:
         | 
| 461 | 
            +
                                print(".", end="", flush=True)
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                # Save to cache.
         | 
| 464 | 
            +
                if cache:
         | 
| 465 | 
            +
                    safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
         | 
| 466 | 
            +
                    cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
         | 
| 467 | 
            +
                    temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
         | 
| 468 | 
            +
                    os.makedirs(cache_dir, exist_ok=True)
         | 
| 469 | 
            +
                    with open(temp_file, "wb") as f:
         | 
| 470 | 
            +
                        f.write(url_data)
         | 
| 471 | 
            +
                    os.replace(temp_file, cache_file) # atomic
         | 
| 472 | 
            +
                    if return_filename:
         | 
| 473 | 
            +
                        return cache_file
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                # Return data as file object.
         | 
| 476 | 
            +
                assert not return_filename
         | 
| 477 | 
            +
                return io.BytesIO(url_data)
         | 
    	
        PTI/docs/joker_original.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/docs/joker_rotation.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/docs/model_rec.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/docs/stern_rotation.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/docs/teaser.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/docs/tyron_edit.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/docs/tyron_original.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        PTI/editings/ganspace.py
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def edit(latents, pca, edit_directions):
         | 
| 5 | 
            +
                edit_latents = []
         | 
| 6 | 
            +
                for latent in latents:
         | 
| 7 | 
            +
                    for pca_idx, start, end, strength in edit_directions:
         | 
| 8 | 
            +
                        delta = get_delta(pca, latent, pca_idx, strength)
         | 
| 9 | 
            +
                        delta_padded = torch.zeros(latent.shape).to('cuda')
         | 
| 10 | 
            +
                        delta_padded[start:end] += delta.repeat(end - start, 1)
         | 
| 11 | 
            +
                        edit_latents.append(latent + delta_padded)
         | 
| 12 | 
            +
                return torch.stack(edit_latents)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def get_delta(pca, latent, idx, strength):
         | 
| 16 | 
            +
                w_centered = latent - pca['mean'].to('cuda')
         | 
| 17 | 
            +
                lat_comp = pca['comp'].to('cuda')
         | 
| 18 | 
            +
                lat_std = pca['std'].to('cuda')
         | 
| 19 | 
            +
                w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
         | 
| 20 | 
            +
                delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
         | 
| 21 | 
            +
                return delta
         | 
    	
        PTI/editings/ganspace_pca/ffhq_pca.pt
    ADDED
    
    | Binary file (168 kB). View file | 
|  | 
    	
        PTI/editings/interfacegan_directions/age.pt
    ADDED
    
    | Binary file (2.81 kB). View file | 
|  | 
    	
        PTI/editings/interfacegan_directions/rotation.pt
    ADDED
    
    | Binary file (2.81 kB). View file | 
|  | 
    	
        PTI/editings/interfacegan_directions/smile.pt
    ADDED
    
    | Binary file (2.81 kB). View file | 
|  | 
    	
        PTI/editings/latent_editor.py
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from configs import paths_config
         | 
| 4 | 
            +
            from editings import ganspace
         | 
| 5 | 
            +
            from utils.data_utils import tensor2im
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class LatentEditor(object):
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                def apply_ganspace(self, latent, ganspace_pca, edit_directions):
         | 
| 11 | 
            +
                    edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
         | 
| 12 | 
            +
                    return edit_latents
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
         | 
| 15 | 
            +
                    edit_latents = []
         | 
| 16 | 
            +
                    if factor_range is not None:  # Apply a range of editing factors. for example, (-5, 5)
         | 
| 17 | 
            +
                        for f in range(*factor_range):
         | 
| 18 | 
            +
                            edit_latent = latent + f * direction
         | 
| 19 | 
            +
                            edit_latents.append(edit_latent)
         | 
| 20 | 
            +
                        edit_latents = torch.cat(edit_latents)
         | 
| 21 | 
            +
                    else:
         | 
| 22 | 
            +
                        edit_latents = latent + factor * direction
         | 
| 23 | 
            +
                    return edit_latents
         | 
    	
        PTI/evaluation/experiment_setting_creator.py
    ADDED
    
    | @@ -0,0 +1,43 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import glob
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            from configs import global_config, paths_config, hyperparameters
         | 
| 4 | 
            +
            from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator
         | 
| 5 | 
            +
            from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator
         | 
| 6 | 
            +
            from scripts.run_pti import run_PTI
         | 
| 7 | 
            +
            import pickle
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            from utils.models_utils import toogle_grad, load_old_G
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class ExperimentRunner:
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def __init__(self, run_id=''):
         | 
| 15 | 
            +
                    self.images_paths = glob.glob(f'{paths_config.input_data_path}/*')
         | 
| 16 | 
            +
                    self.target_paths = glob.glob(f'{paths_config.input_data_path}/*')
         | 
| 17 | 
            +
                    self.run_id = run_id
         | 
| 18 | 
            +
                    self.sampled_ws = None
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    self.old_G = load_old_G()
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    toogle_grad(self.old_G, False)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
         | 
| 25 | 
            +
                    if run_pt:
         | 
| 26 | 
            +
                        self.run_id = run_PTI(self.run_id, use_wandb=use_wandb, use_multi_id_training=use_multi_id_training)
         | 
| 27 | 
            +
                    if create_other_latents:
         | 
| 28 | 
            +
                        sg2_plus_latent_creator = SG2PlusLatentCreator(use_wandb=use_wandb)
         | 
| 29 | 
            +
                        sg2_plus_latent_creator.create_latents()
         | 
| 30 | 
            +
                        e4e_latent_creator = E4ELatentCreator(use_wandb=use_wandb)
         | 
| 31 | 
            +
                        e4e_latent_creator.create_latents()
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    torch.cuda.empty_cache()
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    return self.run_id
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            if __name__ == '__main__':
         | 
| 39 | 
            +
                os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
         | 
| 40 | 
            +
                os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                runner = ExperimentRunner()
         | 
| 43 | 
            +
                runner.run_experiment(True, False, False)
         | 
    	
        PTI/evaluation/qualitative_edit_comparison.py
    ADDED
    
    | @@ -0,0 +1,156 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from random import choice
         | 
| 3 | 
            +
            from string import ascii_uppercase
         | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
            from tqdm import tqdm
         | 
| 6 | 
            +
            from scripts.latent_editor_wrapper import LatentEditorWrapper
         | 
| 7 | 
            +
            from evaluation.experiment_setting_creator import ExperimentRunner
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            from configs import paths_config, hyperparameters, evaluation_config
         | 
| 10 | 
            +
            from utils.log_utils import save_concat_image, save_single_image
         | 
| 11 | 
            +
            from utils.models_utils import load_tuned_G
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class EditComparison:
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def __init__(self, save_single_images, save_concatenated_images, run_id):
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                    self.run_id = run_id
         | 
| 19 | 
            +
                    self.experiment_creator = ExperimentRunner(run_id)
         | 
| 20 | 
            +
                    self.save_single_images = save_single_images
         | 
| 21 | 
            +
                    self.save_concatenated_images = save_concatenated_images
         | 
| 22 | 
            +
                    self.latent_editor = LatentEditorWrapper()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def save_reconstruction_images(self, image_latents, new_inv_image_latent, new_G, target_image):
         | 
| 25 | 
            +
                    if self.save_concatenated_images:
         | 
| 26 | 
            +
                        save_concat_image(self.concat_base_dir, image_latents, new_inv_image_latent, new_G,
         | 
| 27 | 
            +
                                          self.experiment_creator.old_G,
         | 
| 28 | 
            +
                                          'rec',
         | 
| 29 | 
            +
                                          target_image)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    if self.save_single_images:
         | 
| 32 | 
            +
                        save_single_image(self.single_base_dir, new_inv_image_latent, new_G, 'rec')
         | 
| 33 | 
            +
                        target_image.save(f'{self.single_base_dir}/Original.jpg')
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                def create_output_dirs(self, full_image_name):
         | 
| 36 | 
            +
                    output_base_dir_path = f'{paths_config.experiments_output_dir}/{paths_config.input_data_id}/{self.run_id}/{full_image_name}'
         | 
| 37 | 
            +
                    os.makedirs(output_base_dir_path, exist_ok=True)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    self.concat_base_dir = f'{output_base_dir_path}/concat_images'
         | 
| 40 | 
            +
                    self.single_base_dir = f'{output_base_dir_path}/single_images'
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    os.makedirs(self.concat_base_dir, exist_ok=True)
         | 
| 43 | 
            +
                    os.makedirs(self.single_base_dir, exist_ok=True)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def get_image_latent_codes(self, image_name):
         | 
| 46 | 
            +
                    image_latents = []
         | 
| 47 | 
            +
                    for method in evaluation_config.evaluated_methods:
         | 
| 48 | 
            +
                        if method == 'SG2':
         | 
| 49 | 
            +
                            image_latents.append(torch.load(
         | 
| 50 | 
            +
                                f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/'
         | 
| 51 | 
            +
                                f'{paths_config.pti_results_keyword}/{image_name}/0.pt'))
         | 
| 52 | 
            +
                        else:
         | 
| 53 | 
            +
                            image_latents.append(torch.load(
         | 
| 54 | 
            +
                                f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{method}/{image_name}/0.pt'))
         | 
| 55 | 
            +
                    new_inv_image_latent = torch.load(
         | 
| 56 | 
            +
                        f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}/{image_name}/0.pt')
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    return image_latents, new_inv_image_latent
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def save_interfacegan_edits(self, image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image):
         | 
| 61 | 
            +
                    new_w_inv_edits = self.latent_editor.get_single_interface_gan_edits(new_inv_image_latent,
         | 
| 62 | 
            +
                                                                                        interfacegan_factors)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    inv_edits = []
         | 
| 65 | 
            +
                    for latent in image_latents:
         | 
| 66 | 
            +
                        inv_edits.append(self.latent_editor.get_single_interface_gan_edits(latent, interfacegan_factors))
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    for direction, edits in new_w_inv_edits.items():
         | 
| 69 | 
            +
                        for factor, edit_tensor in edits.items():
         | 
| 70 | 
            +
                            if self.save_concatenated_images:
         | 
| 71 | 
            +
                                save_concat_image(self.concat_base_dir, [edits[direction][factor] for edits in inv_edits],
         | 
| 72 | 
            +
                                                  new_w_inv_edits[direction][factor],
         | 
| 73 | 
            +
                                                  new_G,
         | 
| 74 | 
            +
                                                  self.experiment_creator.old_G,
         | 
| 75 | 
            +
                                                  f'{direction}_{factor}', target_image)
         | 
| 76 | 
            +
                            if self.save_single_images:
         | 
| 77 | 
            +
                                save_single_image(self.single_base_dir, new_w_inv_edits[direction][factor], new_G,
         | 
| 78 | 
            +
                                                  f'{direction}_{factor}')
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                def save_ganspace_edits(self, image_latents, new_inv_image_latent, factors, new_G, target_image):
         | 
| 81 | 
            +
                    new_w_inv_edits = self.latent_editor.get_single_ganspace_edits(new_inv_image_latent, factors)
         | 
| 82 | 
            +
                    inv_edits = []
         | 
| 83 | 
            +
                    for latent in image_latents:
         | 
| 84 | 
            +
                        inv_edits.append(self.latent_editor.get_single_ganspace_edits(latent, factors))
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    for idx in range(len(new_w_inv_edits)):
         | 
| 87 | 
            +
                        if self.save_concatenated_images:
         | 
| 88 | 
            +
                            save_concat_image(self.concat_base_dir, [edit[idx] for edit in inv_edits], new_w_inv_edits[idx],
         | 
| 89 | 
            +
                                              new_G,
         | 
| 90 | 
            +
                                              self.experiment_creator.old_G,
         | 
| 91 | 
            +
                                              f'ganspace_{idx}', target_image)
         | 
| 92 | 
            +
                        if self.save_single_images:
         | 
| 93 | 
            +
                            save_single_image(self.single_base_dir, new_w_inv_edits[idx], new_G,
         | 
| 94 | 
            +
                                              f'ganspace_{idx}')
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
         | 
| 97 | 
            +
                    images_counter = 0
         | 
| 98 | 
            +
                    new_G = None
         | 
| 99 | 
            +
                    interfacegan_factors = [val / 2 for val in range(-6, 7) if val != 0]
         | 
| 100 | 
            +
                    ganspace_factors = range(-20, 25, 5)
         | 
| 101 | 
            +
                    self.experiment_creator.run_experiment(run_pt, create_other_latents, use_multi_id_training, use_wandb)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    if use_multi_id_training:
         | 
| 104 | 
            +
                        new_G = load_tuned_G(self.run_id, paths_config.multi_id_model_type)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    for idx, image_path in tqdm(enumerate(self.experiment_creator.images_paths),
         | 
| 107 | 
            +
                                                total=len(self.experiment_creator.images_paths)):
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                        if images_counter >= hyperparameters.max_images_to_invert:
         | 
| 110 | 
            +
                            break
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                        image_name = image_path.split('.')[0].split('/')[-1]
         | 
| 113 | 
            +
                        target_image = Image.open(self.experiment_creator.target_paths[idx])
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                        if not use_multi_id_training:
         | 
| 116 | 
            +
                            new_G = load_tuned_G(self.run_id, image_name)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                        image_latents, new_inv_image_latent = self.get_image_latent_codes(image_name)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                        self.create_output_dirs(image_name)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                        self.save_reconstruction_images(image_latents, new_inv_image_latent, new_G, target_image)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        self.save_interfacegan_edits(image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                        self.save_ganspace_edits(image_latents, new_inv_image_latent, ganspace_factors, new_G, target_image)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                        target_image.close()
         | 
| 129 | 
            +
                        torch.cuda.empty_cache()
         | 
| 130 | 
            +
                        images_counter += 1
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            def run_pti_and_full_edit(iid):
         | 
| 134 | 
            +
                evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
         | 
| 135 | 
            +
                edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
         | 
| 136 | 
            +
                                                     run_id=f'{paths_config.input_data_id}_pti_full_edit_{iid}')
         | 
| 137 | 
            +
                edit_figure_creator.run_experiment(True, True, use_multi_id_training=False, use_wandb=False)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            def pti_no_comparison(iid):
         | 
| 141 | 
            +
                evaluation_config.evaluated_methods = []
         | 
| 142 | 
            +
                edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
         | 
| 143 | 
            +
                                                     run_id=f'{paths_config.input_data_id}_pti_no_comparison_{iid}')
         | 
| 144 | 
            +
                edit_figure_creator.run_experiment(True, False, use_multi_id_training=False, use_wandb=False)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def edits_for_existed_experiment(run_id):
         | 
| 148 | 
            +
                evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
         | 
| 149 | 
            +
                edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
         | 
| 150 | 
            +
                                                     run_id=run_id)
         | 
| 151 | 
            +
                edit_figure_creator.run_experiment(False, True, use_multi_id_training=False, use_wandb=False)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            if __name__ == '__main__':
         | 
| 155 | 
            +
                iid = ''.join(choice(ascii_uppercase) for i in range(7))
         | 
| 156 | 
            +
                pti_no_comparison(iid)
         | 
    	
        PTI/models/StyleCLIP/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        PTI/models/StyleCLIP/criteria/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        PTI/models/StyleCLIP/criteria/clip_loss.py
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import clip
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class CLIPLoss(torch.nn.Module):
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                def __init__(self, opts):
         | 
| 9 | 
            +
                    super(CLIPLoss, self).__init__()
         | 
| 10 | 
            +
                    self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
         | 
| 11 | 
            +
                    self.upsample = torch.nn.Upsample(scale_factor=7)
         | 
| 12 | 
            +
                    self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def forward(self, image, text):
         | 
| 15 | 
            +
                    image = self.avg_pool(self.upsample(image))
         | 
| 16 | 
            +
                    similarity = 1 - self.model(image, text)[0] / 100
         | 
| 17 | 
            +
                    return similarity
         | 
    	
        PTI/models/StyleCLIP/criteria/id_loss.py
    ADDED
    
    | @@ -0,0 +1,39 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from models.facial_recognition.model_irse import Backbone
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class IDLoss(nn.Module):
         | 
| 8 | 
            +
                def __init__(self, opts):
         | 
| 9 | 
            +
                    super(IDLoss, self).__init__()
         | 
| 10 | 
            +
                    print('Loading ResNet ArcFace')
         | 
| 11 | 
            +
                    self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
         | 
| 12 | 
            +
                    self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
         | 
| 13 | 
            +
                    self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
         | 
| 14 | 
            +
                    self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
         | 
| 15 | 
            +
                    self.facenet.eval()
         | 
| 16 | 
            +
                    self.opts = opts
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                def extract_feats(self, x):
         | 
| 19 | 
            +
                    if x.shape[2] != 256:
         | 
| 20 | 
            +
                        x = self.pool(x)
         | 
| 21 | 
            +
                    x = x[:, :, 35:223, 32:220]  # Crop interesting region
         | 
| 22 | 
            +
                    x = self.face_pool(x)
         | 
| 23 | 
            +
                    x_feats = self.facenet(x)
         | 
| 24 | 
            +
                    return x_feats
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def forward(self, y_hat, y):
         | 
| 27 | 
            +
                    n_samples = y.shape[0]
         | 
| 28 | 
            +
                    y_feats = self.extract_feats(y)  # Otherwise use the feature from there
         | 
| 29 | 
            +
                    y_hat_feats = self.extract_feats(y_hat)
         | 
| 30 | 
            +
                    y_feats = y_feats.detach()
         | 
| 31 | 
            +
                    loss = 0
         | 
| 32 | 
            +
                    sim_improvement = 0
         | 
| 33 | 
            +
                    count = 0
         | 
| 34 | 
            +
                    for i in range(n_samples):
         | 
| 35 | 
            +
                        diff_target = y_hat_feats[i].dot(y_feats[i])
         | 
| 36 | 
            +
                        loss += 1 - diff_target
         | 
| 37 | 
            +
                        count += 1
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    return loss / count, sim_improvement / count
         | 
    	
        PTI/models/StyleCLIP/global_directions/GUI.py
    ADDED
    
    | @@ -0,0 +1,103 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
             | 
| 3 | 
            +
            from tkinter import Tk,Frame ,Label,Button,messagebox,Canvas,Text,Scale
         | 
| 4 | 
            +
            from tkinter import  HORIZONTAL
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            class View():
         | 
| 7 | 
            +
                def __init__(self,master):
         | 
| 8 | 
            +
                    
         | 
| 9 | 
            +
                    self.width=600
         | 
| 10 | 
            +
                    self.height=600
         | 
| 11 | 
            +
                    
         | 
| 12 | 
            +
                    
         | 
| 13 | 
            +
                    self.root=master
         | 
| 14 | 
            +
                    self.root.geometry("600x600")
         | 
| 15 | 
            +
                    
         | 
| 16 | 
            +
                    self.left_frame=Frame(self.root,width=600)
         | 
| 17 | 
            +
                    self.left_frame.pack_propagate(0)
         | 
| 18 | 
            +
                    self.left_frame.pack(fill='both', side='left', expand='True')
         | 
| 19 | 
            +
                    
         | 
| 20 | 
            +
                    self.retrieval_frame=Frame(self.root,bg='snow3')
         | 
| 21 | 
            +
                    self.retrieval_frame.pack_propagate(0)
         | 
| 22 | 
            +
                    self.retrieval_frame.pack(fill='both', side='right', expand='True')
         | 
| 23 | 
            +
                    
         | 
| 24 | 
            +
                    self.bg_frame=Frame(self.left_frame,bg='snow3',height=600,width=600)
         | 
| 25 | 
            +
                    self.bg_frame.pack_propagate(0)
         | 
| 26 | 
            +
                    self.bg_frame.pack(fill='both', side='top', expand='True')
         | 
| 27 | 
            +
                    
         | 
| 28 | 
            +
                    self.command_frame=Frame(self.left_frame,bg='snow3')
         | 
| 29 | 
            +
                    self.command_frame.pack_propagate(0)
         | 
| 30 | 
            +
                    self.command_frame.pack(fill='both', side='bottom', expand='True')
         | 
| 31 | 
            +
            #        self.command_frame.grid(row=1, column=0,padx=0, pady=0)
         | 
| 32 | 
            +
                    
         | 
| 33 | 
            +
                    self.bg=Canvas(self.bg_frame,width=self.width,height=self.height, bg='gray')
         | 
| 34 | 
            +
                    self.bg.place(relx=0.5, rely=0.5, anchor='center')
         | 
| 35 | 
            +
                    
         | 
| 36 | 
            +
                    self.mani=Canvas(self.retrieval_frame,width=1024,height=1024, bg='gray') 
         | 
| 37 | 
            +
                    self.mani.grid(row=0, column=0,padx=0, pady=42)
         | 
| 38 | 
            +
                    
         | 
| 39 | 
            +
                    self.SetCommand()
         | 
| 40 | 
            +
                    
         | 
| 41 | 
            +
                    
         | 
| 42 | 
            +
                    
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                def run(self):
         | 
| 45 | 
            +
                    self.root.mainloop()
         | 
| 46 | 
            +
                
         | 
| 47 | 
            +
                def helloCallBack(self):
         | 
| 48 | 
            +
                    category=self.set_category.get()
         | 
| 49 | 
            +
                    messagebox.showinfo( "Hello Python",category)
         | 
| 50 | 
            +
                
         | 
| 51 | 
            +
                def SetCommand(self):
         | 
| 52 | 
            +
                    
         | 
| 53 | 
            +
                    tmp = Label(self.command_frame, text="neutral", width=10 ,bg='snow3')
         | 
| 54 | 
            +
                    tmp.grid(row=1, column=0,padx=10, pady=10)
         | 
| 55 | 
            +
                    
         | 
| 56 | 
            +
                    tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3')
         | 
| 57 | 
            +
                    tmp.grid(row=1, column=1,padx=10, pady=10)
         | 
| 58 | 
            +
                    
         | 
| 59 | 
            +
                    self.neutral = Text ( self.command_frame, height=2, width=30)
         | 
| 60 | 
            +
                    self.neutral.grid(row=1, column=2,padx=10, pady=10)
         | 
| 61 | 
            +
                    
         | 
| 62 | 
            +
                    
         | 
| 63 | 
            +
                    tmp = Label(self.command_frame, text="target", width=10 ,bg='snow3')
         | 
| 64 | 
            +
                    tmp.grid(row=2, column=0,padx=10, pady=10)
         | 
| 65 | 
            +
                    
         | 
| 66 | 
            +
                    tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3')
         | 
| 67 | 
            +
                    tmp.grid(row=2, column=1,padx=10, pady=10)
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                    self.target = Text ( self.command_frame, height=2, width=30)
         | 
| 70 | 
            +
                    self.target.grid(row=2, column=2,padx=10, pady=10)
         | 
| 71 | 
            +
                    
         | 
| 72 | 
            +
                    tmp = Label(self.command_frame, text="strength", width=10 ,bg='snow3')
         | 
| 73 | 
            +
                    tmp.grid(row=3, column=0,padx=10, pady=10)
         | 
| 74 | 
            +
                    
         | 
| 75 | 
            +
                    self.alpha = Scale(self.command_frame, from_=-15, to=25, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.01)
         | 
| 76 | 
            +
                    self.alpha.grid(row=3, column=2,padx=10, pady=10)
         | 
| 77 | 
            +
                    
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    tmp = Label(self.command_frame, text="disentangle", width=10 ,bg='snow3')
         | 
| 80 | 
            +
                    tmp.grid(row=4, column=0,padx=10, pady=10)
         | 
| 81 | 
            +
                    
         | 
| 82 | 
            +
                    self.beta = Scale(self.command_frame, from_=0.08, to=0.4, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.001)
         | 
| 83 | 
            +
                    self.beta.grid(row=4, column=2,padx=10, pady=10)
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    self.reset = Button(self.command_frame, text='Reset') 
         | 
| 86 | 
            +
                    self.reset.grid(row=5, column=1,padx=10, pady=10)
         | 
| 87 | 
            +
                    
         | 
| 88 | 
            +
                    
         | 
| 89 | 
            +
                    self.set_init = Button(self.command_frame, text='Accept') 
         | 
| 90 | 
            +
                    self.set_init.grid(row=5, column=2,padx=10, pady=10)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            #%%
         | 
| 93 | 
            +
            if __name__ == "__main__":
         | 
| 94 | 
            +
                master=Tk()
         | 
| 95 | 
            +
                self=View(master)
         | 
| 96 | 
            +
                self.run()
         | 
| 97 | 
            +
                
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                
         | 
| 101 | 
            +
                
         | 
| 102 | 
            +
                
         | 
| 103 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/GenerateImg.py
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            from manipulate import Manipulator
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            #%%
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            if __name__ == "__main__":
         | 
| 11 | 
            +
                parser = argparse.ArgumentParser(description='Process some integers.')
         | 
| 12 | 
            +
                
         | 
| 13 | 
            +
                parser.add_argument('--dataset_name',type=str,default='ffhq',
         | 
| 14 | 
            +
                                help='name of dataset, for example, ffhq')
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                args = parser.parse_args()
         | 
| 17 | 
            +
                dataset_name=args.dataset_name
         | 
| 18 | 
            +
                
         | 
| 19 | 
            +
                if not os.path.isdir('./data/'+dataset_name):
         | 
| 20 | 
            +
                    os.system('mkdir ./data/'+dataset_name)
         | 
| 21 | 
            +
                #%%
         | 
| 22 | 
            +
                M=Manipulator(dataset_name=dataset_name)
         | 
| 23 | 
            +
                np.set_printoptions(suppress=True)
         | 
| 24 | 
            +
                print(M.dataset_name)
         | 
| 25 | 
            +
                #%%
         | 
| 26 | 
            +
                
         | 
| 27 | 
            +
                M.img_index=0
         | 
| 28 | 
            +
                M.num_images=50
         | 
| 29 | 
            +
                M.alpha=[0]
         | 
| 30 | 
            +
                M.step=1
         | 
| 31 | 
            +
                lindex,bname=0,0
         | 
| 32 | 
            +
                
         | 
| 33 | 
            +
                M.manipulate_layers=[lindex]
         | 
| 34 | 
            +
                codes,out=M.EditOneC(bname)
         | 
| 35 | 
            +
                #%%
         | 
| 36 | 
            +
                
         | 
| 37 | 
            +
                for i in range(len(out)):
         | 
| 38 | 
            +
                    img=out[i,0]
         | 
| 39 | 
            +
                    img=Image.fromarray(img)
         | 
| 40 | 
            +
                    img.save('./data/'+dataset_name+'/'+str(i)+'.jpg')
         | 
| 41 | 
            +
                #%%
         | 
| 42 | 
            +
                w=np.load('./npy/'+dataset_name+'/W.npy')
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                tmp=w[:M.num_images]
         | 
| 45 | 
            +
                tmp=tmp[:,None,:]
         | 
| 46 | 
            +
                tmp=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1))
         | 
| 47 | 
            +
                
         | 
| 48 | 
            +
                np.save('./data/'+dataset_name+'/w_plus.npy',tmp)
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/GetCode.py
    ADDED
    
    | @@ -0,0 +1,232 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import pickle
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from dnnlib import tflib  
         | 
| 8 | 
            +
            import tensorflow as tf 
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import argparse
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def LoadModel(dataset_name):
         | 
| 13 | 
            +
                # Initialize TensorFlow.
         | 
| 14 | 
            +
                tflib.init_tf()
         | 
| 15 | 
            +
                model_path='./model/'
         | 
| 16 | 
            +
                model_name=dataset_name+'.pkl'
         | 
| 17 | 
            +
                
         | 
| 18 | 
            +
                tmp=os.path.join(model_path,model_name)
         | 
| 19 | 
            +
                with open(tmp, 'rb') as f:
         | 
| 20 | 
            +
                    _, _, Gs = pickle.load(f)
         | 
| 21 | 
            +
                return Gs
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            def lerp(a,b,t):
         | 
| 24 | 
            +
                 return a + (b - a) * t
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            #stylegan-ada
         | 
| 27 | 
            +
            def SelectName(layer_name,suffix):
         | 
| 28 | 
            +
                if suffix==None:
         | 
| 29 | 
            +
                    tmp1='add:0' in layer_name 
         | 
| 30 | 
            +
                    tmp2='shape=(?,' in layer_name
         | 
| 31 | 
            +
                    tmp4='G_synthesis_1' in layer_name
         | 
| 32 | 
            +
                    tmp= tmp1 and tmp2 and tmp4  
         | 
| 33 | 
            +
                else:
         | 
| 34 | 
            +
                    tmp1=('/Conv0_up'+suffix) in layer_name 
         | 
| 35 | 
            +
                    tmp2=('/Conv1'+suffix) in layer_name 
         | 
| 36 | 
            +
                    tmp3=('4x4/Conv'+suffix) in layer_name 
         | 
| 37 | 
            +
                    tmp4='G_synthesis_1' in layer_name
         | 
| 38 | 
            +
                    tmp5=('/ToRGB'+suffix) in layer_name
         | 
| 39 | 
            +
                    tmp= (tmp1 or tmp2 or tmp3 or tmp5) and tmp4 
         | 
| 40 | 
            +
                return tmp
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def GetSNames(suffix):
         | 
| 44 | 
            +
                #get style tensor name 
         | 
| 45 | 
            +
                with tf.Session() as sess:
         | 
| 46 | 
            +
                    op = sess.graph.get_operations()
         | 
| 47 | 
            +
                layers=[m.values() for m in op]
         | 
| 48 | 
            +
                
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                select_layers=[]
         | 
| 51 | 
            +
                for layer in layers:
         | 
| 52 | 
            +
                    layer_name=str(layer)
         | 
| 53 | 
            +
                    if SelectName(layer_name,suffix):
         | 
| 54 | 
            +
                        select_layers.append(layer[0])
         | 
| 55 | 
            +
                return select_layers
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            def SelectName2(layer_name):
         | 
| 58 | 
            +
                tmp1='mod_bias' in layer_name 
         | 
| 59 | 
            +
                tmp2='mod_weight' in layer_name
         | 
| 60 | 
            +
                tmp3='ToRGB' in layer_name 
         | 
| 61 | 
            +
                
         | 
| 62 | 
            +
                tmp= (tmp1 or tmp2) and (not tmp3) 
         | 
| 63 | 
            +
                return tmp
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            def GetKName(Gs):
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
                layers=[var for name, var in Gs.components.synthesis.vars.items()]
         | 
| 68 | 
            +
                
         | 
| 69 | 
            +
                select_layers=[]
         | 
| 70 | 
            +
                for layer in layers:
         | 
| 71 | 
            +
                    layer_name=str(layer)
         | 
| 72 | 
            +
                    if SelectName2(layer_name):
         | 
| 73 | 
            +
                        select_layers.append(layer)
         | 
| 74 | 
            +
                return select_layers
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            def GetCode(Gs,random_state,num_img,num_once,dataset_name):
         | 
| 77 | 
            +
                rnd = np.random.RandomState(random_state)  #5
         | 
| 78 | 
            +
                
         | 
| 79 | 
            +
                truncation_psi=0.7
         | 
| 80 | 
            +
                truncation_cutoff=8
         | 
| 81 | 
            +
                
         | 
| 82 | 
            +
                dlatent_avg=Gs.get_var('dlatent_avg')
         | 
| 83 | 
            +
                
         | 
| 84 | 
            +
                dlatents=np.zeros((num_img,512),dtype='float32')
         | 
| 85 | 
            +
                for i in range(int(num_img/num_once)):
         | 
| 86 | 
            +
                    src_latents =  rnd.randn(num_once, Gs.input_shape[1])
         | 
| 87 | 
            +
                    src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
         | 
| 88 | 
            +
                    
         | 
| 89 | 
            +
                    # Apply truncation trick.
         | 
| 90 | 
            +
                    if truncation_psi is not None and truncation_cutoff is not None:
         | 
| 91 | 
            +
                            layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
         | 
| 92 | 
            +
                            ones = np.ones(layer_idx.shape, dtype=np.float32)
         | 
| 93 | 
            +
                            coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
         | 
| 94 | 
            +
                            src_dlatents_np=lerp(dlatent_avg, src_dlatents, coefs)
         | 
| 95 | 
            +
                            src_dlatents=src_dlatents_np[:,0,:].astype('float32')
         | 
| 96 | 
            +
                            dlatents[(i*num_once):((i+1)*num_once),:]=src_dlatents
         | 
| 97 | 
            +
                print('get all z and w')
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
                tmp='./npy/'+dataset_name+'/W'
         | 
| 100 | 
            +
                np.save(tmp,dlatents)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                
         | 
| 103 | 
            +
            def GetImg(Gs,num_img,num_once,dataset_name,save_name='images'):
         | 
| 104 | 
            +
                print('Generate Image')
         | 
| 105 | 
            +
                tmp='./npy/'+dataset_name+'/W.npy'
         | 
| 106 | 
            +
                dlatents=np.load(tmp) 
         | 
| 107 | 
            +
                fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
         | 
| 108 | 
            +
                
         | 
| 109 | 
            +
                all_images=[]
         | 
| 110 | 
            +
                for i in range(int(num_img/num_once)):
         | 
| 111 | 
            +
                    print(i)
         | 
| 112 | 
            +
                    images=[]
         | 
| 113 | 
            +
                    for k in range(num_once):
         | 
| 114 | 
            +
                        tmp=dlatents[i*num_once+k]
         | 
| 115 | 
            +
                        tmp=tmp[None,None,:]
         | 
| 116 | 
            +
                        tmp=np.tile(tmp,(1,Gs.components.synthesis.input_shape[1],1))
         | 
| 117 | 
            +
                        image2= Gs.components.synthesis.run(tmp, randomize_noise=False, output_transform=fmt)
         | 
| 118 | 
            +
                        images.append(image2)
         | 
| 119 | 
            +
                        
         | 
| 120 | 
            +
                    images=np.concatenate(images)
         | 
| 121 | 
            +
                    
         | 
| 122 | 
            +
                    all_images.append(images)
         | 
| 123 | 
            +
                    
         | 
| 124 | 
            +
                all_images=np.concatenate(all_images)
         | 
| 125 | 
            +
                
         | 
| 126 | 
            +
                tmp='./npy/'+dataset_name+'/'+save_name
         | 
| 127 | 
            +
                np.save(tmp,all_images)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            def GetS(dataset_name,num_img):
         | 
| 130 | 
            +
                print('Generate S')
         | 
| 131 | 
            +
                tmp='./npy/'+dataset_name+'/W.npy'
         | 
| 132 | 
            +
                dlatents=np.load(tmp)[:num_img]
         | 
| 133 | 
            +
                
         | 
| 134 | 
            +
                with tf.Session() as sess:
         | 
| 135 | 
            +
                    init = tf.global_variables_initializer()
         | 
| 136 | 
            +
                    sess.run(init)
         | 
| 137 | 
            +
                    
         | 
| 138 | 
            +
                    Gs=LoadModel(dataset_name)
         | 
| 139 | 
            +
                    Gs.print_layers()  #for ada
         | 
| 140 | 
            +
                    select_layers1=GetSNames(suffix=None)  #None,'/mul_1:0','/mod_weight/read:0','/MatMul:0'
         | 
| 141 | 
            +
                    dlatents=dlatents[:,None,:]
         | 
| 142 | 
            +
                    dlatents=np.tile(dlatents,(1,Gs.components.synthesis.input_shape[1],1))
         | 
| 143 | 
            +
                    
         | 
| 144 | 
            +
                    all_s = sess.run(
         | 
| 145 | 
            +
                        select_layers1,
         | 
| 146 | 
            +
                        feed_dict={'G_synthesis_1/dlatents_in:0': dlatents})
         | 
| 147 | 
            +
                
         | 
| 148 | 
            +
                layer_names=[layer.name for layer in select_layers1]
         | 
| 149 | 
            +
                save_tmp=[layer_names,all_s]
         | 
| 150 | 
            +
                return save_tmp
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False):
         | 
| 156 | 
            +
                """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
         | 
| 157 | 
            +
                Can be used as an output transformation for Network.run().
         | 
| 158 | 
            +
                """
         | 
| 159 | 
            +
                if nchw_to_nhwc:
         | 
| 160 | 
            +
                    images = np.transpose(images, [0, 2, 3, 1])
         | 
| 161 | 
            +
                
         | 
| 162 | 
            +
                scale = 255 / (drange[1] - drange[0])
         | 
| 163 | 
            +
                images = images * scale + (0.5 - drange[0] * scale)
         | 
| 164 | 
            +
                
         | 
| 165 | 
            +
                np.clip(images, 0, 255, out=images)
         | 
| 166 | 
            +
                images=images.astype('uint8')
         | 
| 167 | 
            +
                return images
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
            def GetCodeMS(dlatents):
         | 
| 171 | 
            +
                    m=[]
         | 
| 172 | 
            +
                    std=[]
         | 
| 173 | 
            +
                    for i in range(len(dlatents)):
         | 
| 174 | 
            +
                        tmp= dlatents[i] 
         | 
| 175 | 
            +
                        tmp_mean=tmp.mean(axis=0)
         | 
| 176 | 
            +
                        tmp_std=tmp.std(axis=0)
         | 
| 177 | 
            +
                        m.append(tmp_mean)
         | 
| 178 | 
            +
                        std.append(tmp_std)
         | 
| 179 | 
            +
                    return m,std
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            #%%
         | 
| 184 | 
            +
            if __name__ == "__main__":
         | 
| 185 | 
            +
                
         | 
| 186 | 
            +
                
         | 
| 187 | 
            +
                parser = argparse.ArgumentParser(description='Process some integers.')
         | 
| 188 | 
            +
                
         | 
| 189 | 
            +
                parser.add_argument('--dataset_name',type=str,default='ffhq',
         | 
| 190 | 
            +
                                help='name of dataset, for example, ffhq')
         | 
| 191 | 
            +
                parser.add_argument('--code_type',choices=['w','s','s_mean_std'],default='w')
         | 
| 192 | 
            +
                
         | 
| 193 | 
            +
                args = parser.parse_args()
         | 
| 194 | 
            +
                random_state=5
         | 
| 195 | 
            +
                num_img=100_000 
         | 
| 196 | 
            +
                num_once=1_000
         | 
| 197 | 
            +
                dataset_name=args.dataset_name
         | 
| 198 | 
            +
                
         | 
| 199 | 
            +
                if not os.path.isfile('./model/'+dataset_name+'.pkl'):
         | 
| 200 | 
            +
                    url='https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
         | 
| 201 | 
            +
                    name='stylegan2-'+dataset_name+'-config-f.pkl'
         | 
| 202 | 
            +
                    os.system('wget ' +url+name + '  -P  ./model/')
         | 
| 203 | 
            +
                    os.system('mv ./model/'+name+' ./model/'+dataset_name+'.pkl')
         | 
| 204 | 
            +
                
         | 
| 205 | 
            +
                if not os.path.isdir('./npy/'+dataset_name):
         | 
| 206 | 
            +
                    os.system('mkdir ./npy/'+dataset_name)
         | 
| 207 | 
            +
                
         | 
| 208 | 
            +
                if args.code_type=='w':
         | 
| 209 | 
            +
                    Gs=LoadModel(dataset_name=dataset_name)
         | 
| 210 | 
            +
                    GetCode(Gs,random_state,num_img,num_once,dataset_name)
         | 
| 211 | 
            +
            #        GetImg(Gs,num_img=num_img,num_once=num_once,dataset_name=dataset_name,save_name='images_100K') #no need 
         | 
| 212 | 
            +
                elif args.code_type=='s':
         | 
| 213 | 
            +
                    save_name='S'
         | 
| 214 | 
            +
                    save_tmp=GetS(dataset_name,num_img=2_000)
         | 
| 215 | 
            +
                    tmp='./npy/'+dataset_name+'/'+save_name
         | 
| 216 | 
            +
                    with open(tmp, "wb") as fp:
         | 
| 217 | 
            +
                        pickle.dump(save_tmp, fp)
         | 
| 218 | 
            +
                    
         | 
| 219 | 
            +
                elif args.code_type=='s_mean_std':
         | 
| 220 | 
            +
                    save_tmp=GetS(dataset_name,num_img=num_img)
         | 
| 221 | 
            +
                    dlatents=save_tmp[1]
         | 
| 222 | 
            +
                    m,std=GetCodeMS(dlatents)
         | 
| 223 | 
            +
                    save_tmp=[m,std]
         | 
| 224 | 
            +
                    save_name='S_mean_std'
         | 
| 225 | 
            +
                    tmp='./npy/'+dataset_name+'/'+save_name
         | 
| 226 | 
            +
                    with open(tmp, "wb") as fp:
         | 
| 227 | 
            +
                        pickle.dump(save_tmp, fp)
         | 
| 228 | 
            +
                
         | 
| 229 | 
            +
                
         | 
| 230 | 
            +
                
         | 
| 231 | 
            +
                
         | 
| 232 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/GetGUIData.py
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            from manipulate import Manipulator
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            #%%
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            if __name__ == "__main__":
         | 
| 11 | 
            +
                parser = argparse.ArgumentParser(description='Process some integers.')
         | 
| 12 | 
            +
                
         | 
| 13 | 
            +
                parser.add_argument('--dataset_name',type=str,default='ffhq',
         | 
| 14 | 
            +
                                help='name of dataset, for example, ffhq')
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                parser.add_argument('--real', action='store_true')
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                args = parser.parse_args()
         | 
| 19 | 
            +
                dataset_name=args.dataset_name
         | 
| 20 | 
            +
                
         | 
| 21 | 
            +
                if not os.path.isdir('./data/'+dataset_name):
         | 
| 22 | 
            +
                    os.system('mkdir ./data/'+dataset_name)
         | 
| 23 | 
            +
                #%%
         | 
| 24 | 
            +
                M=Manipulator(dataset_name=dataset_name)
         | 
| 25 | 
            +
                np.set_printoptions(suppress=True)
         | 
| 26 | 
            +
                print(M.dataset_name)
         | 
| 27 | 
            +
                #%%
         | 
| 28 | 
            +
                #remove all .jpg
         | 
| 29 | 
            +
                names=os.listdir('./data/'+dataset_name+'/')
         | 
| 30 | 
            +
                for name in names:
         | 
| 31 | 
            +
                    if '.jpg' in name:
         | 
| 32 | 
            +
                        os.system('rm ./data/'+dataset_name+'/'+name)
         | 
| 33 | 
            +
                
         | 
| 34 | 
            +
                
         | 
| 35 | 
            +
                #%%
         | 
| 36 | 
            +
                if args.real:
         | 
| 37 | 
            +
                    latents=torch.load('./data/'+dataset_name+'/latents.pt')
         | 
| 38 | 
            +
                    w_plus=latents.cpu().detach().numpy()
         | 
| 39 | 
            +
                else:
         | 
| 40 | 
            +
                    w=np.load('./npy/'+dataset_name+'/W.npy')
         | 
| 41 | 
            +
                    tmp=w[:50] #only use 50 images
         | 
| 42 | 
            +
                    tmp=tmp[:,None,:]
         | 
| 43 | 
            +
                    w_plus=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1))
         | 
| 44 | 
            +
                np.save('./data/'+dataset_name+'/w_plus.npy',w_plus)
         | 
| 45 | 
            +
                
         | 
| 46 | 
            +
                #%%
         | 
| 47 | 
            +
                tmp=M.W2S(w_plus)
         | 
| 48 | 
            +
                M.dlatents=tmp
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                M.img_index=0
         | 
| 51 | 
            +
                M.num_images=len(w_plus)
         | 
| 52 | 
            +
                M.alpha=[0]
         | 
| 53 | 
            +
                M.step=1
         | 
| 54 | 
            +
                lindex,bname=0,0
         | 
| 55 | 
            +
                
         | 
| 56 | 
            +
                M.manipulate_layers=[lindex]
         | 
| 57 | 
            +
                codes,out=M.EditOneC(bname)
         | 
| 58 | 
            +
                #%%
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                for i in range(len(out)):
         | 
| 61 | 
            +
                    img=out[i,0]
         | 
| 62 | 
            +
                    img=Image.fromarray(img)
         | 
| 63 | 
            +
                    img.save('./data/'+dataset_name+'/'+str(i)+'.jpg')
         | 
| 64 | 
            +
                #%%
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                
         | 
| 67 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/Inference.py
    ADDED
    
    | @@ -0,0 +1,106 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
             | 
| 3 | 
            +
            from manipulate import Manipulator
         | 
| 4 | 
            +
            import tensorflow as tf
         | 
| 5 | 
            +
            import numpy as np 
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import clip
         | 
| 8 | 
            +
            from MapTS import GetBoundary,GetDt
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            class StyleCLIP():
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
                def __init__(self,dataset_name='ffhq'):
         | 
| 13 | 
            +
                    print('load clip')
         | 
| 14 | 
            +
                    device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 15 | 
            +
                    self.model, preprocess = clip.load("ViT-B/32", device=device)
         | 
| 16 | 
            +
                    self.LoadData(dataset_name)
         | 
| 17 | 
            +
                    
         | 
| 18 | 
            +
                def LoadData(self, dataset_name):
         | 
| 19 | 
            +
                    tf.keras.backend.clear_session()
         | 
| 20 | 
            +
                    M=Manipulator(dataset_name=dataset_name)
         | 
| 21 | 
            +
                    np.set_printoptions(suppress=True)
         | 
| 22 | 
            +
                    fs3=np.load('./npy/'+dataset_name+'/fs3.npy')
         | 
| 23 | 
            +
                    
         | 
| 24 | 
            +
                    self.M=M
         | 
| 25 | 
            +
                    self.fs3=fs3
         | 
| 26 | 
            +
                    
         | 
| 27 | 
            +
                    w_plus=np.load('./data/'+dataset_name+'/w_plus.npy')
         | 
| 28 | 
            +
                    self.M.dlatents=M.W2S(w_plus)
         | 
| 29 | 
            +
                    
         | 
| 30 | 
            +
                    if dataset_name=='ffhq':
         | 
| 31 | 
            +
                        self.c_threshold=20
         | 
| 32 | 
            +
                    else:
         | 
| 33 | 
            +
                        self.c_threshold=100
         | 
| 34 | 
            +
                    self.SetInitP()
         | 
| 35 | 
            +
                
         | 
| 36 | 
            +
                def SetInitP(self):
         | 
| 37 | 
            +
                    self.M.alpha=[3]
         | 
| 38 | 
            +
                    self.M.num_images=1
         | 
| 39 | 
            +
                    
         | 
| 40 | 
            +
                    self.target=''
         | 
| 41 | 
            +
                    self.neutral=''
         | 
| 42 | 
            +
                    self.GetDt2()
         | 
| 43 | 
            +
                    img_index=0
         | 
| 44 | 
            +
                    self.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.M.dlatents]
         | 
| 45 | 
            +
                    
         | 
| 46 | 
            +
                    
         | 
| 47 | 
            +
                def GetDt2(self):
         | 
| 48 | 
            +
                    classnames=[self.target,self.neutral]
         | 
| 49 | 
            +
                    dt=GetDt(classnames,self.model)
         | 
| 50 | 
            +
                    
         | 
| 51 | 
            +
                    self.dt=dt
         | 
| 52 | 
            +
                    num_cs=[]
         | 
| 53 | 
            +
                    betas=np.arange(0.1,0.3,0.01)
         | 
| 54 | 
            +
                    for i in range(len(betas)):
         | 
| 55 | 
            +
                        boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=betas[i])
         | 
| 56 | 
            +
                        print(betas[i])
         | 
| 57 | 
            +
                        num_cs.append(num_c)
         | 
| 58 | 
            +
                    
         | 
| 59 | 
            +
                    num_cs=np.array(num_cs)
         | 
| 60 | 
            +
                    select=num_cs>self.c_threshold
         | 
| 61 | 
            +
                    
         | 
| 62 | 
            +
                    if sum(select)==0:
         | 
| 63 | 
            +
                        self.beta=0.1
         | 
| 64 | 
            +
                    else:
         | 
| 65 | 
            +
                        self.beta=betas[select][-1]
         | 
| 66 | 
            +
                    
         | 
| 67 | 
            +
                
         | 
| 68 | 
            +
                def GetCode(self):
         | 
| 69 | 
            +
                    boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=self.beta)
         | 
| 70 | 
            +
                    codes=self.M.MSCode(self.M.dlatent_tmp,boundary_tmp2)
         | 
| 71 | 
            +
                    return codes
         | 
| 72 | 
            +
                
         | 
| 73 | 
            +
                def GetImg(self):
         | 
| 74 | 
            +
                    
         | 
| 75 | 
            +
                    codes=self.GetCode()
         | 
| 76 | 
            +
                    out=self.M.GenerateImg(codes)
         | 
| 77 | 
            +
                    img=out[0,0]
         | 
| 78 | 
            +
                    return img
         | 
| 79 | 
            +
                
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            #%%
         | 
| 84 | 
            +
            if __name__ == "__main__":
         | 
| 85 | 
            +
                style_clip=StyleCLIP()
         | 
| 86 | 
            +
                self=style_clip
         | 
| 87 | 
            +
                
         | 
| 88 | 
            +
                
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
                
         | 
| 91 | 
            +
                
         | 
| 92 | 
            +
                
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                
         | 
| 95 | 
            +
                
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
                
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                
         | 
| 101 | 
            +
                
         | 
| 102 | 
            +
                
         | 
| 103 | 
            +
                
         | 
| 104 | 
            +
                
         | 
| 105 | 
            +
                
         | 
| 106 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/MapTS.py
    ADDED
    
    | @@ -0,0 +1,394 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python3
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
            Created on Thu Feb  4 17:36:31 2021
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            @author: wuzongze
         | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
         | 
| 11 | 
            +
            #os.environ["CUDA_VISIBLE_DEVICES"] = "1" #(or "1" or "2")
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import sys 
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #sys.path=['', '/usr/local/tensorflow/avx-avx2-gpu/1.14.0/python3.7/site-packages', '/usr/local/matlab/2018b/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python37.zip', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/lib-dynload', '/usr/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/copkmeans-1.5-py3.7.egg', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/spherecluster-0.1.7-py3.7.egg', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages/IPython/extensions']
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import tensorflow as tf
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            import numpy as np 
         | 
| 20 | 
            +
            import torch
         | 
| 21 | 
            +
            import clip
         | 
| 22 | 
            +
            from PIL import Image
         | 
| 23 | 
            +
            import pickle
         | 
| 24 | 
            +
            import copy
         | 
| 25 | 
            +
            import matplotlib.pyplot as plt
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            def GetAlign(out,dt,model,preprocess):
         | 
| 28 | 
            +
                imgs=out
         | 
| 29 | 
            +
                imgs1=imgs.reshape([-1]+list(imgs.shape[2:]))
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
                tmp=[]
         | 
| 32 | 
            +
                for i in range(len(imgs1)):
         | 
| 33 | 
            +
                    
         | 
| 34 | 
            +
                    img=Image.fromarray(imgs1[i])
         | 
| 35 | 
            +
                    image = preprocess(img).unsqueeze(0).to(device)
         | 
| 36 | 
            +
                    tmp.append(image)
         | 
| 37 | 
            +
                
         | 
| 38 | 
            +
                image=torch.cat(tmp)
         | 
| 39 | 
            +
                
         | 
| 40 | 
            +
                with torch.no_grad():
         | 
| 41 | 
            +
                    image_features = model.encode_image(image)
         | 
| 42 | 
            +
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                image_features1=image_features.cpu().numpy()
         | 
| 45 | 
            +
                
         | 
| 46 | 
            +
                image_features1=image_features1.reshape(list(imgs.shape[:2])+[512])
         | 
| 47 | 
            +
                
         | 
| 48 | 
            +
                fd=image_features1[:,1:,:]-image_features1[:,:-1,:]
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                fd1=fd.reshape([-1,512])
         | 
| 51 | 
            +
                fd2=fd1/np.linalg.norm(fd1,axis=1)[:,None]
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
                tmp=np.dot(fd2,dt)
         | 
| 54 | 
            +
                m=tmp.mean()
         | 
| 55 | 
            +
                acc=np.sum(tmp>0)/len(tmp)
         | 
| 56 | 
            +
                print(m,acc)
         | 
| 57 | 
            +
                return m,acc
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            def SplitS(ds_p,M,if_std):
         | 
| 61 | 
            +
                all_ds=[]
         | 
| 62 | 
            +
                start=0
         | 
| 63 | 
            +
                for i in M.mindexs:
         | 
| 64 | 
            +
                    tmp=M.dlatents[i].shape[1]
         | 
| 65 | 
            +
                    end=start+tmp
         | 
| 66 | 
            +
                    tmp=ds_p[start:end]
         | 
| 67 | 
            +
            #        tmp=tmp*M.code_std[i]
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                    all_ds.append(tmp)
         | 
| 70 | 
            +
                    start=end
         | 
| 71 | 
            +
                
         | 
| 72 | 
            +
                all_ds2=[]
         | 
| 73 | 
            +
                tmp_index=0
         | 
| 74 | 
            +
                for i in range(len(M.s_names)):
         | 
| 75 | 
            +
                    if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index])==0):
         | 
| 76 | 
            +
                        
         | 
| 77 | 
            +
            #            tmp=np.abs(all_ds[tmp_index]/M.code_std[i])
         | 
| 78 | 
            +
            #            print(i,tmp.mean())
         | 
| 79 | 
            +
            #            tmp=np.dot(M.latent_codes[i],all_ds[tmp_index])
         | 
| 80 | 
            +
            #            print(tmp)
         | 
| 81 | 
            +
                        if if_std:
         | 
| 82 | 
            +
                            tmp=all_ds[tmp_index]*M.code_std[i]
         | 
| 83 | 
            +
                        else:
         | 
| 84 | 
            +
                            tmp=all_ds[tmp_index]
         | 
| 85 | 
            +
                        
         | 
| 86 | 
            +
                        all_ds2.append(tmp)
         | 
| 87 | 
            +
                        tmp_index+=1
         | 
| 88 | 
            +
                    else:
         | 
| 89 | 
            +
                        tmp=np.zeros(len(M.dlatents[i][0]))
         | 
| 90 | 
            +
                        all_ds2.append(tmp)
         | 
| 91 | 
            +
                return all_ds2
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            imagenet_templates = [
         | 
| 95 | 
            +
                'a bad photo of a {}.',
         | 
| 96 | 
            +
            #    'a photo of many {}.',
         | 
| 97 | 
            +
                'a sculpture of a {}.',
         | 
| 98 | 
            +
                'a photo of the hard to see {}.',
         | 
| 99 | 
            +
                'a low resolution photo of the {}.',
         | 
| 100 | 
            +
                'a rendering of a {}.',
         | 
| 101 | 
            +
                'graffiti of a {}.',
         | 
| 102 | 
            +
                'a bad photo of the {}.',
         | 
| 103 | 
            +
                'a cropped photo of the {}.',
         | 
| 104 | 
            +
                'a tattoo of a {}.',
         | 
| 105 | 
            +
                'the embroidered {}.',
         | 
| 106 | 
            +
                'a photo of a hard to see {}.',
         | 
| 107 | 
            +
                'a bright photo of a {}.',
         | 
| 108 | 
            +
                'a photo of a clean {}.',
         | 
| 109 | 
            +
                'a photo of a dirty {}.',
         | 
| 110 | 
            +
                'a dark photo of the {}.',
         | 
| 111 | 
            +
                'a drawing of a {}.',
         | 
| 112 | 
            +
                'a photo of my {}.',
         | 
| 113 | 
            +
                'the plastic {}.',
         | 
| 114 | 
            +
                'a photo of the cool {}.',
         | 
| 115 | 
            +
                'a close-up photo of a {}.',
         | 
| 116 | 
            +
                'a black and white photo of the {}.',
         | 
| 117 | 
            +
                'a painting of the {}.',
         | 
| 118 | 
            +
                'a painting of a {}.',
         | 
| 119 | 
            +
                'a pixelated photo of the {}.',
         | 
| 120 | 
            +
                'a sculpture of the {}.',
         | 
| 121 | 
            +
                'a bright photo of the {}.',
         | 
| 122 | 
            +
                'a cropped photo of a {}.',
         | 
| 123 | 
            +
                'a plastic {}.',
         | 
| 124 | 
            +
                'a photo of the dirty {}.',
         | 
| 125 | 
            +
                'a jpeg corrupted photo of a {}.',
         | 
| 126 | 
            +
                'a blurry photo of the {}.',
         | 
| 127 | 
            +
                'a photo of the {}.',
         | 
| 128 | 
            +
                'a good photo of the {}.',
         | 
| 129 | 
            +
                'a rendering of the {}.',
         | 
| 130 | 
            +
                'a {} in a video game.',
         | 
| 131 | 
            +
                'a photo of one {}.',
         | 
| 132 | 
            +
                'a doodle of a {}.',
         | 
| 133 | 
            +
                'a close-up photo of the {}.',
         | 
| 134 | 
            +
                'a photo of a {}.',
         | 
| 135 | 
            +
                'the origami {}.',
         | 
| 136 | 
            +
                'the {} in a video game.',
         | 
| 137 | 
            +
                'a sketch of a {}.',
         | 
| 138 | 
            +
                'a doodle of the {}.',
         | 
| 139 | 
            +
                'a origami {}.',
         | 
| 140 | 
            +
                'a low resolution photo of a {}.',
         | 
| 141 | 
            +
                'the toy {}.',
         | 
| 142 | 
            +
                'a rendition of the {}.',
         | 
| 143 | 
            +
                'a photo of the clean {}.',
         | 
| 144 | 
            +
                'a photo of a large {}.',
         | 
| 145 | 
            +
                'a rendition of a {}.',
         | 
| 146 | 
            +
                'a photo of a nice {}.',
         | 
| 147 | 
            +
                'a photo of a weird {}.',
         | 
| 148 | 
            +
                'a blurry photo of a {}.',
         | 
| 149 | 
            +
                'a cartoon {}.',
         | 
| 150 | 
            +
                'art of a {}.',
         | 
| 151 | 
            +
                'a sketch of the {}.',
         | 
| 152 | 
            +
                'a embroidered {}.',
         | 
| 153 | 
            +
                'a pixelated photo of a {}.',
         | 
| 154 | 
            +
                'itap of the {}.',
         | 
| 155 | 
            +
                'a jpeg corrupted photo of the {}.',
         | 
| 156 | 
            +
                'a good photo of a {}.',
         | 
| 157 | 
            +
                'a plushie {}.',
         | 
| 158 | 
            +
                'a photo of the nice {}.',
         | 
| 159 | 
            +
                'a photo of the small {}.',
         | 
| 160 | 
            +
                'a photo of the weird {}.',
         | 
| 161 | 
            +
                'the cartoon {}.',
         | 
| 162 | 
            +
                'art of the {}.',
         | 
| 163 | 
            +
                'a drawing of the {}.',
         | 
| 164 | 
            +
                'a photo of the large {}.',
         | 
| 165 | 
            +
                'a black and white photo of a {}.',
         | 
| 166 | 
            +
                'the plushie {}.',
         | 
| 167 | 
            +
                'a dark photo of a {}.',
         | 
| 168 | 
            +
                'itap of a {}.',
         | 
| 169 | 
            +
                'graffiti of the {}.',
         | 
| 170 | 
            +
                'a toy {}.',
         | 
| 171 | 
            +
                'itap of my {}.',
         | 
| 172 | 
            +
                'a photo of a cool {}.',
         | 
| 173 | 
            +
                'a photo of a small {}.',
         | 
| 174 | 
            +
                'a tattoo of the {}.',
         | 
| 175 | 
            +
            ]
         | 
| 176 | 
            +
             | 
| 177 | 
            +
             | 
| 178 | 
            +
            def zeroshot_classifier(classnames, templates,model):
         | 
| 179 | 
            +
                with torch.no_grad():
         | 
| 180 | 
            +
                    zeroshot_weights = []
         | 
| 181 | 
            +
                    for classname in classnames:
         | 
| 182 | 
            +
                        texts = [template.format(classname) for template in templates] #format with class
         | 
| 183 | 
            +
                        texts = clip.tokenize(texts).cuda() #tokenize
         | 
| 184 | 
            +
                        class_embeddings = model.encode_text(texts) #embed with text encoder
         | 
| 185 | 
            +
                        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
         | 
| 186 | 
            +
                        class_embedding = class_embeddings.mean(dim=0)
         | 
| 187 | 
            +
                        class_embedding /= class_embedding.norm()
         | 
| 188 | 
            +
                        zeroshot_weights.append(class_embedding)
         | 
| 189 | 
            +
                    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
         | 
| 190 | 
            +
                return zeroshot_weights
         | 
| 191 | 
            +
             | 
| 192 | 
            +
             | 
| 193 | 
            +
            def GetDt(classnames,model):
         | 
| 194 | 
            +
                text_features=zeroshot_classifier(classnames, imagenet_templates,model).t()
         | 
| 195 | 
            +
                
         | 
| 196 | 
            +
                dt=text_features[0]-text_features[1]
         | 
| 197 | 
            +
                dt=dt.cpu().numpy()
         | 
| 198 | 
            +
                
         | 
| 199 | 
            +
            #    t_m1=t_m/np.linalg.norm(t_m)
         | 
| 200 | 
            +
            #    dt=text_features.cpu().numpy()[0]-t_m1
         | 
| 201 | 
            +
                print(np.linalg.norm(dt))
         | 
| 202 | 
            +
                dt=dt/np.linalg.norm(dt)
         | 
| 203 | 
            +
                return dt
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            def GetBoundary(fs3,dt,M,threshold):
         | 
| 207 | 
            +
                tmp=np.dot(fs3,dt)
         | 
| 208 | 
            +
                
         | 
| 209 | 
            +
                ds_imp=copy.copy(tmp)
         | 
| 210 | 
            +
                select=np.abs(tmp)<threshold
         | 
| 211 | 
            +
                num_c=np.sum(~select)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
             | 
| 214 | 
            +
                ds_imp[select]=0
         | 
| 215 | 
            +
                tmp=np.abs(ds_imp).max()
         | 
| 216 | 
            +
                ds_imp/=tmp
         | 
| 217 | 
            +
                
         | 
| 218 | 
            +
                boundary_tmp2=SplitS(ds_imp,M,if_std=True)
         | 
| 219 | 
            +
                print('num of channels being manipulated:',num_c)
         | 
| 220 | 
            +
                return boundary_tmp2,num_c
         | 
| 221 | 
            +
             | 
| 222 | 
            +
            def GetFs(file_path):
         | 
| 223 | 
            +
                fs=np.load(file_path+'single_channel.npy')
         | 
| 224 | 
            +
                tmp=np.linalg.norm(fs,axis=-1)
         | 
| 225 | 
            +
                fs1=fs/tmp[:,:,:,None]
         | 
| 226 | 
            +
                fs2=fs1[:,:,1,:]-fs1[:,:,0,:]  # 5*sigma - (-5)* sigma
         | 
| 227 | 
            +
                fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None]
         | 
| 228 | 
            +
                fs3=fs3.mean(axis=1)
         | 
| 229 | 
            +
                fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None]
         | 
| 230 | 
            +
                return fs3
         | 
| 231 | 
            +
            #%%
         | 
| 232 | 
            +
             | 
| 233 | 
            +
            if __name__ == "__main__":
         | 
| 234 | 
            +
                device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 235 | 
            +
                model, preprocess = clip.load("ViT-B/32", device=device)
         | 
| 236 | 
            +
                #%%
         | 
| 237 | 
            +
                sys.path.append('/cs/labs/danix/wuzongze/Gan_Manipulation/play')
         | 
| 238 | 
            +
                from example_try import Manipulator4
         | 
| 239 | 
            +
                
         | 
| 240 | 
            +
                M=Manipulator4(dataset_name='ffhq',code_type='S')
         | 
| 241 | 
            +
                np.set_printoptions(suppress=True)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                #%%
         | 
| 244 | 
            +
                
         | 
| 245 | 
            +
                
         | 
| 246 | 
            +
                file_path='/cs/labs/danix/wuzongze/Tansformer_Manipulation/CLIP/results/'+M.dataset_name+'/'
         | 
| 247 | 
            +
                fs3=GetFs(file_path)
         | 
| 248 | 
            +
                
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                
         | 
| 251 | 
            +
                #%%
         | 
| 252 | 
            +
                '''
         | 
| 253 | 
            +
                text_features=zeroshot_classifier2(classnames, imagenet_templates) #.t()
         | 
| 254 | 
            +
                    
         | 
| 255 | 
            +
                tmp=np.linalg.norm(text_features,axis=2)
         | 
| 256 | 
            +
                text_features/=tmp[:,:,None]
         | 
| 257 | 
            +
                dt=text_features[0]-text_features[1]
         | 
| 258 | 
            +
                
         | 
| 259 | 
            +
                tmp=np.linalg.norm(dt,axis=1)
         | 
| 260 | 
            +
                dt/=tmp[:,None]
         | 
| 261 | 
            +
                dt=dt.mean(axis=0)
         | 
| 262 | 
            +
                '''
         | 
| 263 | 
            +
                
         | 
| 264 | 
            +
                #%%
         | 
| 265 | 
            +
                '''
         | 
| 266 | 
            +
                all_tmp=[]
         | 
| 267 | 
            +
                tmp=torch.load('/cs/labs/danix/wuzongze/downloads/harris_latent.pt')
         | 
| 268 | 
            +
                tmp=tmp.cpu().detach().numpy() #[:,:14,:]
         | 
| 269 | 
            +
                all_tmp.append(tmp)
         | 
| 270 | 
            +
                
         | 
| 271 | 
            +
                tmp=torch.load('/cs/labs/danix/wuzongze/downloads/ariana_latent.pt')
         | 
| 272 | 
            +
                tmp=tmp.cpu().detach().numpy() #[:,:14,:]
         | 
| 273 | 
            +
                all_tmp.append(tmp)
         | 
| 274 | 
            +
                
         | 
| 275 | 
            +
                tmp=torch.load('/cs/labs/danix/wuzongze/downloads/federer.pt')
         | 
| 276 | 
            +
                tmp=tmp.cpu().detach().numpy() #[:,:14,:]
         | 
| 277 | 
            +
                all_tmp.append(tmp)
         | 
| 278 | 
            +
                
         | 
| 279 | 
            +
                all_tmp=np.array(all_tmp)[:,0]
         | 
| 280 | 
            +
                
         | 
| 281 | 
            +
                dlatent_tmp=M.W2S(all_tmp)
         | 
| 282 | 
            +
                '''
         | 
| 283 | 
            +
                '''
         | 
| 284 | 
            +
                tmp=torch.load('/cs/labs/danix/wuzongze/downloads/all_cars.pt')
         | 
| 285 | 
            +
                tmp=tmp.cpu().detach().numpy()[:300]
         | 
| 286 | 
            +
                dlatent_tmp=M.W2S(tmp)
         | 
| 287 | 
            +
                '''
         | 
| 288 | 
            +
                '''
         | 
| 289 | 
            +
                tmp=torch.load('/cs/labs/danix/wuzongze/downloads/faces.pt')
         | 
| 290 | 
            +
                tmp=tmp.cpu().detach().numpy()[:100]
         | 
| 291 | 
            +
                dlatent_tmp=M.W2S(tmp)
         | 
| 292 | 
            +
                '''
         | 
| 293 | 
            +
                #%%
         | 
| 294 | 
            +
            #    M.viz_size=1024
         | 
| 295 | 
            +
                M.img_index=0
         | 
| 296 | 
            +
                M.num_images=30
         | 
| 297 | 
            +
                dlatent_tmp=[tmp[M.img_index:(M.img_index+M.num_images)] for tmp in M.dlatents]
         | 
| 298 | 
            +
                #%%
         | 
| 299 | 
            +
                
         | 
| 300 | 
            +
                classnames=['face','face with glasses']
         | 
| 301 | 
            +
                
         | 
| 302 | 
            +
            #    classnames=['car','classic car']
         | 
| 303 | 
            +
            #    classnames=['dog','happy dog']
         | 
| 304 | 
            +
            #    classnames=['bedroom','modern bedroom']
         | 
| 305 | 
            +
                
         | 
| 306 | 
            +
            #    classnames=['church','church without watermark']
         | 
| 307 | 
            +
            #    classnames=['natural scene','natural scene without grass']
         | 
| 308 | 
            +
                dt=GetDt(classnames,model)
         | 
| 309 | 
            +
            #    tmp=np.dot(fs3,dt)
         | 
| 310 | 
            +
            #    
         | 
| 311 | 
            +
            #    ds_imp=copy.copy(tmp)
         | 
| 312 | 
            +
            #    select=np.abs(tmp)<0.1
         | 
| 313 | 
            +
            #    num_c=np.sum(~select)
         | 
| 314 | 
            +
            #
         | 
| 315 | 
            +
            #
         | 
| 316 | 
            +
            #    ds_imp[select]=0
         | 
| 317 | 
            +
            #    tmp=np.abs(ds_imp).max()
         | 
| 318 | 
            +
            #    ds_imp/=tmp
         | 
| 319 | 
            +
            #    
         | 
| 320 | 
            +
            #    boundary_tmp2=SplitS(ds_imp,M,if_std=True)
         | 
| 321 | 
            +
            #    print('num of channels being manipulated:',num_c)
         | 
| 322 | 
            +
                
         | 
| 323 | 
            +
                boundary_tmp2=GetBoundary(fs3,dt,M,threshold=0.13)
         | 
| 324 | 
            +
                
         | 
| 325 | 
            +
                #%%
         | 
| 326 | 
            +
                M.start_distance=-20
         | 
| 327 | 
            +
                M.end_distance=20
         | 
| 328 | 
            +
                M.step=7
         | 
| 329 | 
            +
            #    M.num_images=100
         | 
| 330 | 
            +
                codes=M.MSCode(dlatent_tmp,boundary_tmp2)
         | 
| 331 | 
            +
                out=M.GenerateImg(codes)
         | 
| 332 | 
            +
                M.Vis2(str('tmp'),'filter2',out)
         | 
| 333 | 
            +
                
         | 
| 334 | 
            +
            #    full=GetAlign(out,dt,model,preprocess)
         | 
| 335 | 
            +
                
         | 
| 336 | 
            +
                
         | 
| 337 | 
            +
                #%%
         | 
| 338 | 
            +
                boundary_tmp3=copy.copy(boundary_tmp2) #primary
         | 
| 339 | 
            +
                boundary_tmp4=copy.copy(boundary_tmp2) #condition
         | 
| 340 | 
            +
                #%%
         | 
| 341 | 
            +
                boundary_tmp2=copy.copy(boundary_tmp3)
         | 
| 342 | 
            +
                for i in range(len(boundary_tmp3)):
         | 
| 343 | 
            +
                    select=boundary_tmp4[i]==0
         | 
| 344 | 
            +
                    boundary_tmp2[i][~select]=0
         | 
| 345 | 
            +
                
         | 
| 346 | 
            +
                
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                
         | 
| 349 | 
            +
                
         | 
| 350 | 
            +
                
         | 
| 351 | 
            +
                
         | 
| 352 | 
            +
                #%%1
         | 
| 353 | 
            +
                
         | 
| 354 | 
            +
                
         | 
| 355 | 
            +
             | 
| 356 | 
            +
             | 
| 357 | 
            +
             | 
| 358 | 
            +
             | 
| 359 | 
            +
             | 
| 360 | 
            +
             | 
| 361 | 
            +
             | 
| 362 | 
            +
             | 
| 363 | 
            +
             | 
| 364 | 
            +
             | 
| 365 | 
            +
             | 
| 366 | 
            +
             | 
| 367 | 
            +
             | 
| 368 | 
            +
             | 
| 369 | 
            +
             | 
| 370 | 
            +
             | 
| 371 | 
            +
             | 
| 372 | 
            +
             | 
| 373 | 
            +
             | 
| 374 | 
            +
             | 
| 375 | 
            +
             | 
| 376 | 
            +
             | 
| 377 | 
            +
             | 
| 378 | 
            +
             | 
| 379 | 
            +
             | 
| 380 | 
            +
             | 
| 381 | 
            +
             | 
| 382 | 
            +
             | 
| 383 | 
            +
             | 
| 384 | 
            +
             | 
| 385 | 
            +
             | 
| 386 | 
            +
             | 
| 387 | 
            +
             | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
            +
             | 
| 391 | 
            +
             | 
| 392 | 
            +
             | 
| 393 | 
            +
             | 
| 394 | 
            +
             | 
    	
        PTI/models/StyleCLIP/global_directions/PlayInteractively.py
    ADDED
    
    | @@ -0,0 +1,197 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            from tkinter import Tk 
         | 
| 5 | 
            +
            from PIL import Image, ImageTk
         | 
| 6 | 
            +
            from tkinter.filedialog import askopenfilename
         | 
| 7 | 
            +
            from GUI import View
         | 
| 8 | 
            +
            from Inference import StyleCLIP
         | 
| 9 | 
            +
            import argparse
         | 
| 10 | 
            +
            #%%
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class PlayInteractively():  #Controller
         | 
| 14 | 
            +
                '''
         | 
| 15 | 
            +
                followed Model View Controller Design Pattern
         | 
| 16 | 
            +
                
         | 
| 17 | 
            +
                controller, model, view
         | 
| 18 | 
            +
                '''
         | 
| 19 | 
            +
                def __init__(self,dataset_name='ffhq'):
         | 
| 20 | 
            +
                    
         | 
| 21 | 
            +
                    self.root = Tk()
         | 
| 22 | 
            +
                    self.view=View(self.root)
         | 
| 23 | 
            +
                    self.img_ratio=2
         | 
| 24 | 
            +
                    self.style_clip=StyleCLIP(dataset_name)
         | 
| 25 | 
            +
                    
         | 
| 26 | 
            +
                    self.view.neutral.bind("<Return>", self.text_n)
         | 
| 27 | 
            +
                    self.view.target.bind("<Return>", self.text_t)
         | 
| 28 | 
            +
                    self.view.alpha.bind('<ButtonRelease-1>', self.ChangeAlpha)
         | 
| 29 | 
            +
                    self.view.beta.bind('<ButtonRelease-1>', self.ChangeBeta)
         | 
| 30 | 
            +
                    self.view.set_init.bind('<ButtonPress-1>', self.SetInit) 
         | 
| 31 | 
            +
                    self.view.reset.bind('<ButtonPress-1>', self.Reset) 
         | 
| 32 | 
            +
                    self.view.bg.bind('<Double-1>', self.open_img)
         | 
| 33 | 
            +
                    
         | 
| 34 | 
            +
                    
         | 
| 35 | 
            +
                    self.drawn  = None
         | 
| 36 | 
            +
                    
         | 
| 37 | 
            +
                    self.view.target.delete(1.0, "end")
         | 
| 38 | 
            +
                    self.view.target.insert("end", self.style_clip.target)
         | 
| 39 | 
            +
            #        
         | 
| 40 | 
            +
                    self.view.neutral.delete(1.0, "end")
         | 
| 41 | 
            +
                    self.view.neutral.insert("end", self.style_clip.neutral)
         | 
| 42 | 
            +
                    
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                def Reset(self,event):
         | 
| 45 | 
            +
                    self.style_clip.GetDt2()
         | 
| 46 | 
            +
                    self.style_clip.M.alpha=[0]
         | 
| 47 | 
            +
                    
         | 
| 48 | 
            +
                    self.view.beta.set(self.style_clip.beta)
         | 
| 49 | 
            +
                    self.view.alpha.set(0)
         | 
| 50 | 
            +
                    
         | 
| 51 | 
            +
                    img=self.style_clip.GetImg()
         | 
| 52 | 
            +
                    img=Image.fromarray(img)
         | 
| 53 | 
            +
                    img = ImageTk.PhotoImage(img)
         | 
| 54 | 
            +
                    self.addImage_m(img)
         | 
| 55 | 
            +
                    
         | 
| 56 | 
            +
                
         | 
| 57 | 
            +
                def SetInit(self,event):
         | 
| 58 | 
            +
                    codes=self.style_clip.GetCode()
         | 
| 59 | 
            +
                    self.style_clip.M.dlatent_tmp=[tmp[:,0] for tmp in codes]
         | 
| 60 | 
            +
                    print('set init')
         | 
| 61 | 
            +
                
         | 
| 62 | 
            +
                def ChangeAlpha(self,event):
         | 
| 63 | 
            +
                    tmp=self.view.alpha.get()
         | 
| 64 | 
            +
                    self.style_clip.M.alpha=[float(tmp)]
         | 
| 65 | 
            +
                    
         | 
| 66 | 
            +
                    img=self.style_clip.GetImg()
         | 
| 67 | 
            +
                    print('manipulate one')
         | 
| 68 | 
            +
                    img=Image.fromarray(img)
         | 
| 69 | 
            +
                    img = ImageTk.PhotoImage(img)
         | 
| 70 | 
            +
                    self.addImage_m(img)
         | 
| 71 | 
            +
                    
         | 
| 72 | 
            +
                def ChangeBeta(self,event):
         | 
| 73 | 
            +
                    tmp=self.view.beta.get()
         | 
| 74 | 
            +
                    self.style_clip.beta=float(tmp)
         | 
| 75 | 
            +
                    
         | 
| 76 | 
            +
                    img=self.style_clip.GetImg()
         | 
| 77 | 
            +
                    print('manipulate one')
         | 
| 78 | 
            +
                    img=Image.fromarray(img)
         | 
| 79 | 
            +
                    img = ImageTk.PhotoImage(img)
         | 
| 80 | 
            +
                    self.addImage_m(img)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def ChangeDataset(self,event):
         | 
| 83 | 
            +
                    
         | 
| 84 | 
            +
                    dataset_name=self.view.set_category.get()
         | 
| 85 | 
            +
                    
         | 
| 86 | 
            +
                    self.style_clip.LoadData(dataset_name)
         | 
| 87 | 
            +
                    
         | 
| 88 | 
            +
                    self.view.target.delete(1.0, "end")
         | 
| 89 | 
            +
                    self.view.target.insert("end", self.style_clip.target)
         | 
| 90 | 
            +
                    
         | 
| 91 | 
            +
                    self.view.neutral.delete(1.0, "end")
         | 
| 92 | 
            +
                    self.view.neutral.insert("end", self.style_clip.neutral)
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                def text_t(self,event):
         | 
| 95 | 
            +
                    tmp=self.view.target.get("1.0",'end')
         | 
| 96 | 
            +
                    tmp=tmp.replace('\n','')
         | 
| 97 | 
            +
                    
         | 
| 98 | 
            +
                    self.view.target.delete(1.0, "end")
         | 
| 99 | 
            +
                    self.view.target.insert("end", tmp)
         | 
| 100 | 
            +
                    
         | 
| 101 | 
            +
                    print('target',tmp,'###')
         | 
| 102 | 
            +
                    self.style_clip.target=tmp
         | 
| 103 | 
            +
                    self.style_clip.GetDt2()
         | 
| 104 | 
            +
                    self.view.beta.set(self.style_clip.beta)
         | 
| 105 | 
            +
                    self.view.alpha.set(3)
         | 
| 106 | 
            +
                    self.style_clip.M.alpha=[3]
         | 
| 107 | 
            +
                    
         | 
| 108 | 
            +
                    img=self.style_clip.GetImg()
         | 
| 109 | 
            +
                    print('manipulate one')
         | 
| 110 | 
            +
                    img=Image.fromarray(img)
         | 
| 111 | 
            +
                    img = ImageTk.PhotoImage(img)
         | 
| 112 | 
            +
                    self.addImage_m(img)
         | 
| 113 | 
            +
                    
         | 
| 114 | 
            +
                    
         | 
| 115 | 
            +
                def text_n(self,event):
         | 
| 116 | 
            +
                    tmp=self.view.neutral.get("1.0",'end')
         | 
| 117 | 
            +
                    tmp=tmp.replace('\n','')
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                    self.view.neutral.delete(1.0, "end")
         | 
| 120 | 
            +
                    self.view.neutral.insert("end", tmp)
         | 
| 121 | 
            +
                    
         | 
| 122 | 
            +
                    print('neutral',tmp,'###')
         | 
| 123 | 
            +
                    self.style_clip.neutral=tmp
         | 
| 124 | 
            +
                    self.view.target.delete(1.0, "end")
         | 
| 125 | 
            +
                    self.view.target.insert("end", tmp)
         | 
| 126 | 
            +
                    
         | 
| 127 | 
            +
                    
         | 
| 128 | 
            +
                def run(self):
         | 
| 129 | 
            +
                    self.root.mainloop()
         | 
| 130 | 
            +
                
         | 
| 131 | 
            +
                def addImage(self,img):
         | 
| 132 | 
            +
                    self.view.bg.create_image(self.view.width/2, self.view.height/2, image=img, anchor='center')
         | 
| 133 | 
            +
                    self.image=img #save a copy of image. if not the image will disappear
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                def addImage_m(self,img):
         | 
| 136 | 
            +
                    self.view.mani.create_image(512, 512, image=img, anchor='center')
         | 
| 137 | 
            +
                    self.image2=img
         | 
| 138 | 
            +
                    
         | 
| 139 | 
            +
                
         | 
| 140 | 
            +
                def openfn(self):
         | 
| 141 | 
            +
                    filename = askopenfilename(title='open',initialdir='./data/'+self.style_clip.M.dataset_name+'/',filetypes=[("all image format", ".jpg"),("all image format", ".png")])
         | 
| 142 | 
            +
                    return filename
         | 
| 143 | 
            +
                
         | 
| 144 | 
            +
                def open_img(self,event):
         | 
| 145 | 
            +
                    x = self.openfn()
         | 
| 146 | 
            +
                    print(x)
         | 
| 147 | 
            +
                    
         | 
| 148 | 
            +
                    
         | 
| 149 | 
            +
                    img = Image.open(x)
         | 
| 150 | 
            +
                    img2 = img.resize(( 512,512), Image.ANTIALIAS)
         | 
| 151 | 
            +
                    img2 = ImageTk.PhotoImage(img2)
         | 
| 152 | 
            +
                    self.addImage(img2)
         | 
| 153 | 
            +
                    
         | 
| 154 | 
            +
                    img = ImageTk.PhotoImage(img)
         | 
| 155 | 
            +
                    self.addImage_m(img)
         | 
| 156 | 
            +
                    
         | 
| 157 | 
            +
                    img_index=x.split('/')[-1].split('.')[0]
         | 
| 158 | 
            +
                    img_index=int(img_index)
         | 
| 159 | 
            +
                    print(img_index)
         | 
| 160 | 
            +
                    self.style_clip.M.img_index=img_index
         | 
| 161 | 
            +
                    self.style_clip.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.style_clip.M.dlatents]
         | 
| 162 | 
            +
                    
         | 
| 163 | 
            +
                    
         | 
| 164 | 
            +
                    self.style_clip.GetDt2()
         | 
| 165 | 
            +
                    self.view.beta.set(self.style_clip.beta)
         | 
| 166 | 
            +
                    self.view.alpha.set(3)
         | 
| 167 | 
            +
                    
         | 
| 168 | 
            +
                #%%
         | 
| 169 | 
            +
            if __name__ == "__main__":
         | 
| 170 | 
            +
                parser = argparse.ArgumentParser(description='Process some integers.')
         | 
| 171 | 
            +
                
         | 
| 172 | 
            +
                parser.add_argument('--dataset_name',type=str,default='ffhq',
         | 
| 173 | 
            +
                                help='name of dataset, for example, ffhq')
         | 
| 174 | 
            +
                
         | 
| 175 | 
            +
                args = parser.parse_args()
         | 
| 176 | 
            +
                dataset_name=args.dataset_name
         | 
| 177 | 
            +
                
         | 
| 178 | 
            +
                self=PlayInteractively(dataset_name)
         | 
| 179 | 
            +
                self.run()
         | 
| 180 | 
            +
                
         | 
| 181 | 
            +
                
         | 
| 182 | 
            +
                
         | 
| 183 | 
            +
                
         | 
| 184 | 
            +
                
         | 
| 185 | 
            +
                
         | 
| 186 | 
            +
                
         | 
| 187 | 
            +
                
         | 
| 188 | 
            +
                
         | 
| 189 | 
            +
                
         | 
| 190 | 
            +
                
         | 
| 191 | 
            +
                
         | 
| 192 | 
            +
                
         | 
| 193 | 
            +
                
         | 
| 194 | 
            +
                
         | 
| 195 | 
            +
                
         | 
| 196 | 
            +
                
         | 
| 197 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/SingleChannel.py
    ADDED
    
    | @@ -0,0 +1,109 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            import numpy as np 
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import clip
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            import copy
         | 
| 9 | 
            +
            from manipulate import Manipulator
         | 
| 10 | 
            +
            import argparse
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def GetImgF(out,model,preprocess):
         | 
| 13 | 
            +
                imgs=out
         | 
| 14 | 
            +
                imgs1=imgs.reshape([-1]+list(imgs.shape[2:]))
         | 
| 15 | 
            +
                
         | 
| 16 | 
            +
                tmp=[]
         | 
| 17 | 
            +
                for i in range(len(imgs1)):
         | 
| 18 | 
            +
                    
         | 
| 19 | 
            +
                    img=Image.fromarray(imgs1[i])
         | 
| 20 | 
            +
                    image = preprocess(img).unsqueeze(0).to(device)
         | 
| 21 | 
            +
                    tmp.append(image)
         | 
| 22 | 
            +
                
         | 
| 23 | 
            +
                image=torch.cat(tmp)
         | 
| 24 | 
            +
                with torch.no_grad():
         | 
| 25 | 
            +
                    image_features = model.encode_image(image)
         | 
| 26 | 
            +
                
         | 
| 27 | 
            +
                image_features1=image_features.cpu().numpy()
         | 
| 28 | 
            +
                image_features1=image_features1.reshape(list(imgs.shape[:2])+[512])
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                return image_features1
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            def GetFs(fs):
         | 
| 33 | 
            +
                tmp=np.linalg.norm(fs,axis=-1)
         | 
| 34 | 
            +
                fs1=fs/tmp[:,:,:,None]
         | 
| 35 | 
            +
                fs2=fs1[:,:,1,:]-fs1[:,:,0,:]  # 5*sigma - (-5)* sigma
         | 
| 36 | 
            +
                fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None]
         | 
| 37 | 
            +
                fs3=fs3.mean(axis=1)
         | 
| 38 | 
            +
                fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None]
         | 
| 39 | 
            +
                return fs3
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            #%%
         | 
| 42 | 
            +
            if __name__ == "__main__":
         | 
| 43 | 
            +
                parser = argparse.ArgumentParser(description='Process some integers.')
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                parser.add_argument('--dataset_name',type=str,default='cat',
         | 
| 46 | 
            +
                                help='name of dataset, for example, ffhq')
         | 
| 47 | 
            +
                args = parser.parse_args()
         | 
| 48 | 
            +
                dataset_name=args.dataset_name
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                #%%
         | 
| 51 | 
            +
                device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 52 | 
            +
                model, preprocess = clip.load("ViT-B/32", device=device)
         | 
| 53 | 
            +
                #%%
         | 
| 54 | 
            +
                M=Manipulator(dataset_name=dataset_name)
         | 
| 55 | 
            +
                np.set_printoptions(suppress=True)
         | 
| 56 | 
            +
                print(M.dataset_name)
         | 
| 57 | 
            +
                #%%
         | 
| 58 | 
            +
                img_sindex=0
         | 
| 59 | 
            +
                num_images=100
         | 
| 60 | 
            +
                dlatents_o=[]
         | 
| 61 | 
            +
                tmp=img_sindex*num_images
         | 
| 62 | 
            +
                for i in range(len(M.dlatents)):
         | 
| 63 | 
            +
                    tmp1=M.dlatents[i][tmp:(tmp+num_images)]
         | 
| 64 | 
            +
                    dlatents_o.append(tmp1)
         | 
| 65 | 
            +
                #%%
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
                all_f=[]
         | 
| 68 | 
            +
                M.alpha=[-5,5] #ffhq 5
         | 
| 69 | 
            +
                M.step=2
         | 
| 70 | 
            +
                M.num_images=num_images
         | 
| 71 | 
            +
                select=np.array(M.mindexs)<=16 #below or equal to 128 resolution 
         | 
| 72 | 
            +
                mindexs2=np.array(M.mindexs)[select]
         | 
| 73 | 
            +
                for lindex in mindexs2: #ignore ToRGB layers
         | 
| 74 | 
            +
                    print(lindex)
         | 
| 75 | 
            +
                    num_c=M.dlatents[lindex].shape[1]
         | 
| 76 | 
            +
                    for cindex in range(num_c):
         | 
| 77 | 
            +
                        
         | 
| 78 | 
            +
                        M.dlatents=copy.copy(dlatents_o)
         | 
| 79 | 
            +
                        M.dlatents[lindex][:,cindex]=M.code_mean[lindex][cindex]
         | 
| 80 | 
            +
                        
         | 
| 81 | 
            +
                        M.manipulate_layers=[lindex]
         | 
| 82 | 
            +
                        codes,out=M.EditOneC(cindex) 
         | 
| 83 | 
            +
                        image_features1=GetImgF(out,model,preprocess)
         | 
| 84 | 
            +
                        all_f.append(image_features1)
         | 
| 85 | 
            +
                
         | 
| 86 | 
            +
                all_f=np.array(all_f)
         | 
| 87 | 
            +
                
         | 
| 88 | 
            +
                fs3=GetFs(all_f)
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
                #%%
         | 
| 91 | 
            +
                file_path='./npy/'+M.dataset_name+'/'
         | 
| 92 | 
            +
                np.save(file_path+'fs3',fs3)
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                
         | 
| 95 | 
            +
                
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
                
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                
         | 
| 101 | 
            +
                
         | 
| 102 | 
            +
                
         | 
| 103 | 
            +
                
         | 
| 104 | 
            +
                
         | 
| 105 | 
            +
                
         | 
| 106 | 
            +
                
         | 
| 107 | 
            +
                
         | 
| 108 | 
            +
                
         | 
| 109 | 
            +
                
         | 
    	
        PTI/models/StyleCLIP/global_directions/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        PTI/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:394f0f166305654f49cd1b0cd3d4f2b7a51e740a449a1ebfa1c69f79d01399fa
         | 
| 3 | 
            +
            size 2506880
         | 
    	
        PTI/models/StyleCLIP/global_directions/dnnlib/__init__.py
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .util import EasyDict, make_cache_dir_path
         | 
    	
        PTI/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py
    ADDED
    
    | @@ -0,0 +1,20 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from . import autosummary
         | 
| 10 | 
            +
            from . import network
         | 
| 11 | 
            +
            from . import optimizer
         | 
| 12 | 
            +
            from . import tfutil
         | 
| 13 | 
            +
            from . import custom_ops
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from .tfutil import *
         | 
| 16 | 
            +
            from .network import Network
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .optimizer import Optimizer
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from .custom_ops import get_plugin
         | 
    	
        PTI/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py
    ADDED
    
    | @@ -0,0 +1,193 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """Helper for adding automatically tracked values to Tensorboard.
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            Autosummary creates an identity op that internally keeps track of the input
         | 
| 12 | 
            +
            values and automatically shows up in TensorBoard. The reported value
         | 
| 13 | 
            +
            represents an average over input components. The average is accumulated
         | 
| 14 | 
            +
            constantly over time and flushed when save_summaries() is called.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            Notes:
         | 
| 17 | 
            +
            - The output tensor must be used as an input for something else in the
         | 
| 18 | 
            +
              graph. Otherwise, the autosummary op will not get executed, and the average
         | 
| 19 | 
            +
              value will not get accumulated.
         | 
| 20 | 
            +
            - It is perfectly fine to include autosummaries with the same name in
         | 
| 21 | 
            +
              several places throughout the graph, even if they are executed concurrently.
         | 
| 22 | 
            +
            - It is ok to also pass in a python scalar or numpy array. In this case, it
         | 
| 23 | 
            +
              is added to the average immediately.
         | 
| 24 | 
            +
            """
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from collections import OrderedDict
         | 
| 27 | 
            +
            import numpy as np
         | 
| 28 | 
            +
            import tensorflow as tf
         | 
| 29 | 
            +
            from tensorboard import summary as summary_lib
         | 
| 30 | 
            +
            from tensorboard.plugins.custom_scalar import layout_pb2
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            from . import tfutil
         | 
| 33 | 
            +
            from .tfutil import TfExpression
         | 
| 34 | 
            +
            from .tfutil import TfExpressionEx
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            # Enable "Custom scalars" tab in TensorBoard for advanced formatting.
         | 
| 37 | 
            +
            # Disabled by default to reduce tfevents file size.
         | 
| 38 | 
            +
            enable_custom_scalars = False
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            _dtype = tf.float64
         | 
| 41 | 
            +
            _vars = OrderedDict()  # name => [var, ...]
         | 
| 42 | 
            +
            _immediate = OrderedDict()  # name => update_op, update_value
         | 
| 43 | 
            +
            _finalized = False
         | 
| 44 | 
            +
            _merge_op = None
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
         | 
| 48 | 
            +
                """Internal helper for creating autosummary accumulators."""
         | 
| 49 | 
            +
                assert not _finalized
         | 
| 50 | 
            +
                name_id = name.replace("/", "_")
         | 
| 51 | 
            +
                v = tf.cast(value_expr, _dtype)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                if v.shape.is_fully_defined():
         | 
| 54 | 
            +
                    size = np.prod(v.shape.as_list())
         | 
| 55 | 
            +
                    size_expr = tf.constant(size, dtype=_dtype)
         | 
| 56 | 
            +
                else:
         | 
| 57 | 
            +
                    size = None
         | 
| 58 | 
            +
                    size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                if size == 1:
         | 
| 61 | 
            +
                    if v.shape.ndims != 0:
         | 
| 62 | 
            +
                        v = tf.reshape(v, [])
         | 
| 63 | 
            +
                    v = [size_expr, v, tf.square(v)]
         | 
| 64 | 
            +
                else:
         | 
| 65 | 
            +
                    v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
         | 
| 66 | 
            +
                v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
         | 
| 69 | 
            +
                    var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
         | 
| 70 | 
            +
                update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                if name in _vars:
         | 
| 73 | 
            +
                    _vars[name].append(var)
         | 
| 74 | 
            +
                else:
         | 
| 75 | 
            +
                    _vars[name] = [var]
         | 
| 76 | 
            +
                return update_op
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
         | 
| 80 | 
            +
                """Create a new autosummary.
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                Args:
         | 
| 83 | 
            +
                    name:     Name to use in TensorBoard
         | 
| 84 | 
            +
                    value:    TensorFlow expression or python value to track
         | 
| 85 | 
            +
                    passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                Example use of the passthru mechanism:
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                n = autosummary('l2loss', loss, passthru=n)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                This is a shorthand for the following code:
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                with tf.control_dependencies([autosummary('l2loss', loss)]):
         | 
| 94 | 
            +
                    n = tf.identity(n)
         | 
| 95 | 
            +
                """
         | 
| 96 | 
            +
                tfutil.assert_tf_initialized()
         | 
| 97 | 
            +
                name_id = name.replace("/", "_")
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                if tfutil.is_tf_expression(value):
         | 
| 100 | 
            +
                    with tf.name_scope("summary_" + name_id), tf.device(value.device):
         | 
| 101 | 
            +
                        condition = tf.convert_to_tensor(condition, name='condition')
         | 
| 102 | 
            +
                        update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
         | 
| 103 | 
            +
                        with tf.control_dependencies([update_op]):
         | 
| 104 | 
            +
                            return tf.identity(value if passthru is None else passthru)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                else:  # python scalar or numpy array
         | 
| 107 | 
            +
                    assert not tfutil.is_tf_expression(passthru)
         | 
| 108 | 
            +
                    assert not tfutil.is_tf_expression(condition)
         | 
| 109 | 
            +
                    if condition:
         | 
| 110 | 
            +
                        if name not in _immediate:
         | 
| 111 | 
            +
                            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
         | 
| 112 | 
            +
                                update_value = tf.placeholder(_dtype)
         | 
| 113 | 
            +
                                update_op = _create_var(name, update_value)
         | 
| 114 | 
            +
                                _immediate[name] = update_op, update_value
         | 
| 115 | 
            +
                        update_op, update_value = _immediate[name]
         | 
| 116 | 
            +
                        tfutil.run(update_op, {update_value: value})
         | 
| 117 | 
            +
                    return value if passthru is None else passthru
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def finalize_autosummaries() -> None:
         | 
| 121 | 
            +
                """Create the necessary ops to include autosummaries in TensorBoard report.
         | 
| 122 | 
            +
                Note: This should be done only once per graph.
         | 
| 123 | 
            +
                """
         | 
| 124 | 
            +
                global _finalized
         | 
| 125 | 
            +
                tfutil.assert_tf_initialized()
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                if _finalized:
         | 
| 128 | 
            +
                    return None
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                _finalized = True
         | 
| 131 | 
            +
                tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # Create summary ops.
         | 
| 134 | 
            +
                with tf.device(None), tf.control_dependencies(None):
         | 
| 135 | 
            +
                    for name, vars_list in _vars.items():
         | 
| 136 | 
            +
                        name_id = name.replace("/", "_")
         | 
| 137 | 
            +
                        with tfutil.absolute_name_scope("Autosummary/" + name_id):
         | 
| 138 | 
            +
                            moments = tf.add_n(vars_list)
         | 
| 139 | 
            +
                            moments /= moments[0]
         | 
| 140 | 
            +
                            with tf.control_dependencies([moments]):  # read before resetting
         | 
| 141 | 
            +
                                reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
         | 
| 142 | 
            +
                                with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
         | 
| 143 | 
            +
                                    mean = moments[1]
         | 
| 144 | 
            +
                                    std = tf.sqrt(moments[2] - tf.square(moments[1]))
         | 
| 145 | 
            +
                                    tf.summary.scalar(name, mean)
         | 
| 146 | 
            +
                                    if enable_custom_scalars:
         | 
| 147 | 
            +
                                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
         | 
| 148 | 
            +
                                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                # Setup layout for custom scalars.
         | 
| 151 | 
            +
                layout = None
         | 
| 152 | 
            +
                if enable_custom_scalars:
         | 
| 153 | 
            +
                    cat_dict = OrderedDict()
         | 
| 154 | 
            +
                    for series_name in sorted(_vars.keys()):
         | 
| 155 | 
            +
                        p = series_name.split("/")
         | 
| 156 | 
            +
                        cat = p[0] if len(p) >= 2 else ""
         | 
| 157 | 
            +
                        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
         | 
| 158 | 
            +
                        if cat not in cat_dict:
         | 
| 159 | 
            +
                            cat_dict[cat] = OrderedDict()
         | 
| 160 | 
            +
                        if chart not in cat_dict[cat]:
         | 
| 161 | 
            +
                            cat_dict[cat][chart] = []
         | 
| 162 | 
            +
                        cat_dict[cat][chart].append(series_name)
         | 
| 163 | 
            +
                    categories = []
         | 
| 164 | 
            +
                    for cat_name, chart_dict in cat_dict.items():
         | 
| 165 | 
            +
                        charts = []
         | 
| 166 | 
            +
                        for chart_name, series_names in chart_dict.items():
         | 
| 167 | 
            +
                            series = []
         | 
| 168 | 
            +
                            for series_name in series_names:
         | 
| 169 | 
            +
                                series.append(layout_pb2.MarginChartContent.Series(
         | 
| 170 | 
            +
                                    value=series_name,
         | 
| 171 | 
            +
                                    lower="xCustomScalars/" + series_name + "/margin_lo",
         | 
| 172 | 
            +
                                    upper="xCustomScalars/" + series_name + "/margin_hi"))
         | 
| 173 | 
            +
                            margin = layout_pb2.MarginChartContent(series=series)
         | 
| 174 | 
            +
                            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
         | 
| 175 | 
            +
                        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
         | 
| 176 | 
            +
                    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
         | 
| 177 | 
            +
                return layout
         | 
| 178 | 
            +
             | 
| 179 | 
            +
            def save_summaries(file_writer, global_step=None):
         | 
| 180 | 
            +
                """Call FileWriter.add_summary() with all summaries in the default graph,
         | 
| 181 | 
            +
                automatically finalizing and merging them on the first call.
         | 
| 182 | 
            +
                """
         | 
| 183 | 
            +
                global _merge_op
         | 
| 184 | 
            +
                tfutil.assert_tf_initialized()
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                if _merge_op is None:
         | 
| 187 | 
            +
                    layout = finalize_autosummaries()
         | 
| 188 | 
            +
                    if layout is not None:
         | 
| 189 | 
            +
                        file_writer.add_summary(layout)
         | 
| 190 | 
            +
                    with tf.device(None), tf.control_dependencies(None):
         | 
| 191 | 
            +
                        _merge_op = tf.summary.merge_all()
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                file_writer.add_summary(_merge_op.eval(), global_step)
         | 
    	
        PTI/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py
    ADDED
    
    | @@ -0,0 +1,181 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """TensorFlow custom ops builder.
         | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import glob
         | 
| 13 | 
            +
            import os
         | 
| 14 | 
            +
            import re
         | 
| 15 | 
            +
            import uuid
         | 
| 16 | 
            +
            import hashlib
         | 
| 17 | 
            +
            import tempfile
         | 
| 18 | 
            +
            import shutil
         | 
| 19 | 
            +
            import tensorflow as tf
         | 
| 20 | 
            +
            from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from .. import util
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            #----------------------------------------------------------------------------
         | 
| 25 | 
            +
            # Global configs.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            cuda_cache_path = None
         | 
| 28 | 
            +
            cuda_cache_version_tag = 'v1'
         | 
| 29 | 
            +
            do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change.
         | 
| 30 | 
            +
            verbose = True # Print status messages to stdout.
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            #----------------------------------------------------------------------------
         | 
| 33 | 
            +
            # Internal helper funcs.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            def _find_compiler_bindir():
         | 
| 36 | 
            +
                hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
         | 
| 37 | 
            +
                if hostx64_paths != []:
         | 
| 38 | 
            +
                    return hostx64_paths[0]
         | 
| 39 | 
            +
                hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
         | 
| 40 | 
            +
                if hostx64_paths != []:
         | 
| 41 | 
            +
                    return hostx64_paths[0]
         | 
| 42 | 
            +
                hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
         | 
| 43 | 
            +
                if hostx64_paths != []:
         | 
| 44 | 
            +
                    return hostx64_paths[0]
         | 
| 45 | 
            +
                vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
         | 
| 46 | 
            +
                if os.path.isdir(vc_bin_dir):
         | 
| 47 | 
            +
                    return vc_bin_dir
         | 
| 48 | 
            +
                return None
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def _get_compute_cap(device):
         | 
| 51 | 
            +
                caps_str = device.physical_device_desc
         | 
| 52 | 
            +
                m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
         | 
| 53 | 
            +
                major = m.group(1)
         | 
| 54 | 
            +
                minor = m.group(2)
         | 
| 55 | 
            +
                return (major, minor)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            def _get_cuda_gpu_arch_string():
         | 
| 58 | 
            +
                gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
         | 
| 59 | 
            +
                if len(gpus) == 0:
         | 
| 60 | 
            +
                    raise RuntimeError('No GPU devices found')
         | 
| 61 | 
            +
                (major, minor) = _get_compute_cap(gpus[0])
         | 
| 62 | 
            +
                return 'sm_%s%s' % (major, minor)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            def _run_cmd(cmd):
         | 
| 65 | 
            +
                with os.popen(cmd) as pipe:
         | 
| 66 | 
            +
                    output = pipe.read()
         | 
| 67 | 
            +
                    status = pipe.close()
         | 
| 68 | 
            +
                if status is not None:
         | 
| 69 | 
            +
                    raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            def _prepare_nvcc_cli(opts):
         | 
| 72 | 
            +
                cmd = 'nvcc ' + opts.strip()
         | 
| 73 | 
            +
                cmd += ' --disable-warnings'
         | 
| 74 | 
            +
                cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
         | 
| 75 | 
            +
                cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
         | 
| 76 | 
            +
                cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
         | 
| 77 | 
            +
                cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                compiler_bindir = _find_compiler_bindir()
         | 
| 80 | 
            +
                if compiler_bindir is None:
         | 
| 81 | 
            +
                    # Require that _find_compiler_bindir succeeds on Windows.  Allow
         | 
| 82 | 
            +
                    # nvcc to use whatever is the default on Linux.
         | 
| 83 | 
            +
                    if os.name == 'nt':
         | 
| 84 | 
            +
                        raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
         | 
| 85 | 
            +
                else:
         | 
| 86 | 
            +
                    cmd += ' --compiler-bindir "%s"' % compiler_bindir
         | 
| 87 | 
            +
                cmd += ' 2>&1'
         | 
| 88 | 
            +
                return cmd
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            #----------------------------------------------------------------------------
         | 
| 91 | 
            +
            # Main entry point.
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            _plugin_cache = dict()
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            def get_plugin(cuda_file, extra_nvcc_options=[]):
         | 
| 96 | 
            +
                cuda_file_base = os.path.basename(cuda_file)
         | 
| 97 | 
            +
                cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # Already in cache?
         | 
| 100 | 
            +
                if cuda_file in _plugin_cache:
         | 
| 101 | 
            +
                    return _plugin_cache[cuda_file]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                # Setup plugin.
         | 
| 104 | 
            +
                if verbose:
         | 
| 105 | 
            +
                    print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
         | 
| 106 | 
            +
                try:
         | 
| 107 | 
            +
                    # Hash CUDA source.
         | 
| 108 | 
            +
                    md5 = hashlib.md5()
         | 
| 109 | 
            +
                    with open(cuda_file, 'rb') as f:
         | 
| 110 | 
            +
                        md5.update(f.read())
         | 
| 111 | 
            +
                    md5.update(b'\n')
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    # Hash headers included by the CUDA code by running it through the preprocessor.
         | 
| 114 | 
            +
                    if not do_not_hash_included_headers:
         | 
| 115 | 
            +
                        if verbose:
         | 
| 116 | 
            +
                            print('Preprocessing... ', end='', flush=True)
         | 
| 117 | 
            +
                        with tempfile.TemporaryDirectory() as tmp_dir:
         | 
| 118 | 
            +
                            tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
         | 
| 119 | 
            +
                            _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
         | 
| 120 | 
            +
                            with open(tmp_file, 'rb') as f:
         | 
| 121 | 
            +
                                bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
         | 
| 122 | 
            +
                                good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
         | 
| 123 | 
            +
                                for ln in f:
         | 
| 124 | 
            +
                                    if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
         | 
| 125 | 
            +
                                        ln = ln.replace(bad_file_str, good_file_str)
         | 
| 126 | 
            +
                                        md5.update(ln)
         | 
| 127 | 
            +
                                md5.update(b'\n')
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # Select compiler configs.
         | 
| 130 | 
            +
                    compile_opts = ''
         | 
| 131 | 
            +
                    if os.name == 'nt':
         | 
| 132 | 
            +
                        compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
         | 
| 133 | 
            +
                    elif os.name == 'posix':
         | 
| 134 | 
            +
                        compile_opts += f' --compiler-options \'-fPIC\''
         | 
| 135 | 
            +
                        compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\''
         | 
| 136 | 
            +
                        compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\''
         | 
| 137 | 
            +
                    else:
         | 
| 138 | 
            +
                        assert False # not Windows or Linux, w00t?
         | 
| 139 | 
            +
                    compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}'
         | 
| 140 | 
            +
                    compile_opts += ' --use_fast_math'
         | 
| 141 | 
            +
                    for opt in extra_nvcc_options:
         | 
| 142 | 
            +
                        compile_opts += ' ' + opt
         | 
| 143 | 
            +
                    nvcc_cmd = _prepare_nvcc_cli(compile_opts)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    # Hash build configuration.
         | 
| 146 | 
            +
                    md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
         | 
| 147 | 
            +
                    md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
         | 
| 148 | 
            +
                    md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    # Compile if not already compiled.
         | 
| 151 | 
            +
                    cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path
         | 
| 152 | 
            +
                    bin_file_ext = '.dll' if os.name == 'nt' else '.so'
         | 
| 153 | 
            +
                    bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
         | 
| 154 | 
            +
                    if not os.path.isfile(bin_file):
         | 
| 155 | 
            +
                        if verbose:
         | 
| 156 | 
            +
                            print('Compiling... ', end='', flush=True)
         | 
| 157 | 
            +
                        with tempfile.TemporaryDirectory() as tmp_dir:
         | 
| 158 | 
            +
                            tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
         | 
| 159 | 
            +
                            _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
         | 
| 160 | 
            +
                            os.makedirs(cache_dir, exist_ok=True)
         | 
| 161 | 
            +
                            intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
         | 
| 162 | 
            +
                            shutil.copyfile(tmp_file, intermediate_file)
         | 
| 163 | 
            +
                            os.rename(intermediate_file, bin_file) # atomic
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    # Load.
         | 
| 166 | 
            +
                    if verbose:
         | 
| 167 | 
            +
                        print('Loading... ', end='', flush=True)
         | 
| 168 | 
            +
                    plugin = tf.load_op_library(bin_file)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    # Add to cache.
         | 
| 171 | 
            +
                    _plugin_cache[cuda_file] = plugin
         | 
| 172 | 
            +
                    if verbose:
         | 
| 173 | 
            +
                        print('Done.', flush=True)
         | 
| 174 | 
            +
                    return plugin
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                except:
         | 
| 177 | 
            +
                    if verbose:
         | 
| 178 | 
            +
                        print('Failed!', flush=True)
         | 
| 179 | 
            +
                    raise
         | 
| 180 | 
            +
             | 
| 181 | 
            +
            #----------------------------------------------------------------------------
         | 
    	
        PTI/models/StyleCLIP/global_directions/dnnlib/tflib/network.py
    ADDED
    
    | @@ -0,0 +1,781 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """Helper for managing networks."""
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import types
         | 
| 12 | 
            +
            import inspect
         | 
| 13 | 
            +
            import re
         | 
| 14 | 
            +
            import uuid
         | 
| 15 | 
            +
            import sys
         | 
| 16 | 
            +
            import copy
         | 
| 17 | 
            +
            import numpy as np
         | 
| 18 | 
            +
            import tensorflow as tf
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from collections import OrderedDict
         | 
| 21 | 
            +
            from typing import Any, List, Tuple, Union, Callable
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            from . import tfutil
         | 
| 24 | 
            +
            from .. import util
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from .tfutil import TfExpression, TfExpressionEx
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            # pylint: disable=protected-access
         | 
| 29 | 
            +
            # pylint: disable=attribute-defined-outside-init
         | 
| 30 | 
            +
            # pylint: disable=too-many-public-methods
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            _import_handlers = []  # Custom import handlers for dealing with legacy data in pickle import.
         | 
| 33 | 
            +
            _import_module_src = dict()  # Source code for temporary modules created during pickle import.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def import_handler(handler_func):
         | 
| 37 | 
            +
                """Function decorator for declaring custom import handlers."""
         | 
| 38 | 
            +
                _import_handlers.append(handler_func)
         | 
| 39 | 
            +
                return handler_func
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            class Network:
         | 
| 43 | 
            +
                """Generic network abstraction.
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                Acts as a convenience wrapper for a parameterized network construction
         | 
| 46 | 
            +
                function, providing several utility methods and convenient access to
         | 
| 47 | 
            +
                the inputs/outputs/weights.
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                Network objects can be safely pickled and unpickled for long-term
         | 
| 50 | 
            +
                archival purposes. The pickling works reliably as long as the underlying
         | 
| 51 | 
            +
                network construction function is defined in a standalone Python module
         | 
| 52 | 
            +
                that has no side effects or application-specific imports.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                Args:
         | 
| 55 | 
            +
                    name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None.
         | 
| 56 | 
            +
                    func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
         | 
| 57 | 
            +
                    static_kwargs: Keyword arguments to be passed in to the network construction function.
         | 
| 58 | 
            +
                """
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
         | 
| 61 | 
            +
                    # Locate the user-specified build function.
         | 
| 62 | 
            +
                    assert isinstance(func_name, str) or util.is_top_level_function(func_name)
         | 
| 63 | 
            +
                    if util.is_top_level_function(func_name):
         | 
| 64 | 
            +
                        func_name = util.get_top_level_function_name(func_name)
         | 
| 65 | 
            +
                    module, func_name = util.get_module_from_obj_name(func_name)
         | 
| 66 | 
            +
                    func = util.get_obj_from_module(module, func_name)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Dig up source code for the module containing the build function.
         | 
| 69 | 
            +
                    module_src = _import_module_src.get(module, None)
         | 
| 70 | 
            +
                    if module_src is None:
         | 
| 71 | 
            +
                        module_src = inspect.getsource(module)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    # Initialize fields.
         | 
| 74 | 
            +
                    self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None:
         | 
| 77 | 
            +
                    tfutil.assert_tf_initialized()
         | 
| 78 | 
            +
                    assert isinstance(name, str)
         | 
| 79 | 
            +
                    assert len(name) >= 1
         | 
| 80 | 
            +
                    assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name)
         | 
| 81 | 
            +
                    assert isinstance(static_kwargs, dict)
         | 
| 82 | 
            +
                    assert util.is_pickleable(static_kwargs)
         | 
| 83 | 
            +
                    assert callable(build_func)
         | 
| 84 | 
            +
                    assert isinstance(build_func_name, str)
         | 
| 85 | 
            +
                    assert isinstance(build_module_src, str)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    # Choose TensorFlow name scope.
         | 
| 88 | 
            +
                    with tf.name_scope(None):
         | 
| 89 | 
            +
                        scope = tf.get_default_graph().unique_name(name, mark_as_used=True)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    # Query current TensorFlow device.
         | 
| 92 | 
            +
                    with tfutil.absolute_name_scope(scope), tf.control_dependencies(None):
         | 
| 93 | 
            +
                        device = tf.no_op(name="_QueryDevice").device
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    # Immutable state.
         | 
| 96 | 
            +
                    self._name                  = name
         | 
| 97 | 
            +
                    self._scope                 = scope
         | 
| 98 | 
            +
                    self._device                = device
         | 
| 99 | 
            +
                    self._static_kwargs         = util.EasyDict(copy.deepcopy(static_kwargs))
         | 
| 100 | 
            +
                    self._build_func            = build_func
         | 
| 101 | 
            +
                    self._build_func_name       = build_func_name
         | 
| 102 | 
            +
                    self._build_module_src      = build_module_src
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    # State before _init_graph().
         | 
| 105 | 
            +
                    self._var_inits             = dict()    # var_name => initial_value, set to None by _init_graph()
         | 
| 106 | 
            +
                    self._all_inits_known       = False     # Do we know for sure that _var_inits covers all the variables?
         | 
| 107 | 
            +
                    self._components            = None      # subnet_name => Network, None if the components are not known yet
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # Initialized by _init_graph().
         | 
| 110 | 
            +
                    self._input_templates       = None
         | 
| 111 | 
            +
                    self._output_templates      = None
         | 
| 112 | 
            +
                    self._own_vars              = None
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    # Cached values initialized the respective methods.
         | 
| 115 | 
            +
                    self._input_shapes          = None
         | 
| 116 | 
            +
                    self._output_shapes         = None
         | 
| 117 | 
            +
                    self._input_names           = None
         | 
| 118 | 
            +
                    self._output_names          = None
         | 
| 119 | 
            +
                    self._vars                  = None
         | 
| 120 | 
            +
                    self._trainables            = None
         | 
| 121 | 
            +
                    self._var_global_to_local   = None
         | 
| 122 | 
            +
                    self._run_cache             = dict()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def _init_graph(self) -> None:
         | 
| 125 | 
            +
                    assert self._var_inits is not None
         | 
| 126 | 
            +
                    assert self._input_templates is None
         | 
| 127 | 
            +
                    assert self._output_templates is None
         | 
| 128 | 
            +
                    assert self._own_vars is None
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # Initialize components.
         | 
| 131 | 
            +
                    if self._components is None:
         | 
| 132 | 
            +
                        self._components = util.EasyDict()
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    # Choose build func kwargs.
         | 
| 135 | 
            +
                    build_kwargs = dict(self.static_kwargs)
         | 
| 136 | 
            +
                    build_kwargs["is_template_graph"] = True
         | 
| 137 | 
            +
                    build_kwargs["components"] = self._components
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    # Override scope and device, and ignore surrounding control dependencies.
         | 
| 140 | 
            +
                    with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None):
         | 
| 141 | 
            +
                        assert tf.get_variable_scope().name == self.scope
         | 
| 142 | 
            +
                        assert tf.get_default_graph().get_name_scope() == self.scope
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                        # Create input templates.
         | 
| 145 | 
            +
                        self._input_templates = []
         | 
| 146 | 
            +
                        for param in inspect.signature(self._build_func).parameters.values():
         | 
| 147 | 
            +
                            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
         | 
| 148 | 
            +
                                self._input_templates.append(tf.placeholder(tf.float32, name=param.name))
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                        # Call build func.
         | 
| 151 | 
            +
                        out_expr = self._build_func(*self._input_templates, **build_kwargs)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    # Collect output templates and variables.
         | 
| 154 | 
            +
                    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
         | 
| 155 | 
            +
                    self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
         | 
| 156 | 
            +
                    self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    # Check for errors.
         | 
| 159 | 
            +
                    if len(self._input_templates) == 0:
         | 
| 160 | 
            +
                        raise ValueError("Network build func did not list any inputs.")
         | 
| 161 | 
            +
                    if len(self._output_templates) == 0:
         | 
| 162 | 
            +
                        raise ValueError("Network build func did not return any outputs.")
         | 
| 163 | 
            +
                    if any(not tfutil.is_tf_expression(t) for t in self._output_templates):
         | 
| 164 | 
            +
                        raise ValueError("Network outputs must be TensorFlow expressions.")
         | 
| 165 | 
            +
                    if any(t.shape.ndims is None for t in self._input_templates):
         | 
| 166 | 
            +
                        raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
         | 
| 167 | 
            +
                    if any(t.shape.ndims is None for t in self._output_templates):
         | 
| 168 | 
            +
                        raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
         | 
| 169 | 
            +
                    if any(not isinstance(comp, Network) for comp in self._components.values()):
         | 
| 170 | 
            +
                        raise ValueError("Components of a Network must be Networks themselves.")
         | 
| 171 | 
            +
                    if len(self._components) != len(set(comp.name for comp in self._components.values())):
         | 
| 172 | 
            +
                        raise ValueError("Components of a Network must have unique names.")
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # Initialize variables.
         | 
| 175 | 
            +
                    if len(self._var_inits):
         | 
| 176 | 
            +
                        tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()})
         | 
| 177 | 
            +
                    remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits]
         | 
| 178 | 
            +
                    if self._all_inits_known:
         | 
| 179 | 
            +
                        assert len(remaining_inits) == 0
         | 
| 180 | 
            +
                    else:
         | 
| 181 | 
            +
                        tfutil.run(remaining_inits)
         | 
| 182 | 
            +
                    self._var_inits = None
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                @property
         | 
| 185 | 
            +
                def name(self):
         | 
| 186 | 
            +
                    """User-specified name string."""
         | 
| 187 | 
            +
                    return self._name
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                @property
         | 
| 190 | 
            +
                def scope(self):
         | 
| 191 | 
            +
                    """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name."""
         | 
| 192 | 
            +
                    return self._scope
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                @property
         | 
| 195 | 
            +
                def device(self):
         | 
| 196 | 
            +
                    """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time."""
         | 
| 197 | 
            +
                    return self._device
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                @property
         | 
| 200 | 
            +
                def static_kwargs(self):
         | 
| 201 | 
            +
                    """EasyDict of arguments passed to the user-supplied build func."""
         | 
| 202 | 
            +
                    return copy.deepcopy(self._static_kwargs)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                @property
         | 
| 205 | 
            +
                def components(self):
         | 
| 206 | 
            +
                    """EasyDict of sub-networks created by the build func."""
         | 
| 207 | 
            +
                    return copy.copy(self._get_components())
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                def _get_components(self):
         | 
| 210 | 
            +
                    if self._components is None:
         | 
| 211 | 
            +
                        self._init_graph()
         | 
| 212 | 
            +
                        assert self._components is not None
         | 
| 213 | 
            +
                    return self._components
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                @property
         | 
| 216 | 
            +
                def input_shapes(self):
         | 
| 217 | 
            +
                    """List of input tensor shapes, including minibatch dimension."""
         | 
| 218 | 
            +
                    if self._input_shapes is None:
         | 
| 219 | 
            +
                        self._input_shapes = [t.shape.as_list() for t in self.input_templates]
         | 
| 220 | 
            +
                    return copy.deepcopy(self._input_shapes)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                @property
         | 
| 223 | 
            +
                def output_shapes(self):
         | 
| 224 | 
            +
                    """List of output tensor shapes, including minibatch dimension."""
         | 
| 225 | 
            +
                    if self._output_shapes is None:
         | 
| 226 | 
            +
                        self._output_shapes = [t.shape.as_list() for t in self.output_templates]
         | 
| 227 | 
            +
                    return copy.deepcopy(self._output_shapes)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                @property
         | 
| 230 | 
            +
                def input_shape(self):
         | 
| 231 | 
            +
                    """Short-hand for input_shapes[0]."""
         | 
| 232 | 
            +
                    return self.input_shapes[0]
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                @property
         | 
| 235 | 
            +
                def output_shape(self):
         | 
| 236 | 
            +
                    """Short-hand for output_shapes[0]."""
         | 
| 237 | 
            +
                    return self.output_shapes[0]
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                @property
         | 
| 240 | 
            +
                def num_inputs(self):
         | 
| 241 | 
            +
                    """Number of input tensors."""
         | 
| 242 | 
            +
                    return len(self.input_shapes)
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                @property
         | 
| 245 | 
            +
                def num_outputs(self):
         | 
| 246 | 
            +
                    """Number of output tensors."""
         | 
| 247 | 
            +
                    return len(self.output_shapes)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                @property
         | 
| 250 | 
            +
                def input_names(self):
         | 
| 251 | 
            +
                    """Name string for each input."""
         | 
| 252 | 
            +
                    if self._input_names is None:
         | 
| 253 | 
            +
                        self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates]
         | 
| 254 | 
            +
                    return copy.copy(self._input_names)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                @property
         | 
| 257 | 
            +
                def output_names(self):
         | 
| 258 | 
            +
                    """Name string for each output."""
         | 
| 259 | 
            +
                    if self._output_names is None:
         | 
| 260 | 
            +
                        self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
         | 
| 261 | 
            +
                    return copy.copy(self._output_names)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                @property
         | 
| 264 | 
            +
                def input_templates(self):
         | 
| 265 | 
            +
                    """Input placeholders in the template graph."""
         | 
| 266 | 
            +
                    if self._input_templates is None:
         | 
| 267 | 
            +
                        self._init_graph()
         | 
| 268 | 
            +
                        assert self._input_templates is not None
         | 
| 269 | 
            +
                    return copy.copy(self._input_templates)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                @property
         | 
| 272 | 
            +
                def output_templates(self):
         | 
| 273 | 
            +
                    """Output tensors in the template graph."""
         | 
| 274 | 
            +
                    if self._output_templates is None:
         | 
| 275 | 
            +
                        self._init_graph()
         | 
| 276 | 
            +
                        assert self._output_templates is not None
         | 
| 277 | 
            +
                    return copy.copy(self._output_templates)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                @property
         | 
| 280 | 
            +
                def own_vars(self):
         | 
| 281 | 
            +
                    """Variables defined by this network (local_name => var), excluding sub-networks."""
         | 
| 282 | 
            +
                    return copy.copy(self._get_own_vars())
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                def _get_own_vars(self):
         | 
| 285 | 
            +
                    if self._own_vars is None:
         | 
| 286 | 
            +
                        self._init_graph()
         | 
| 287 | 
            +
                        assert self._own_vars is not None
         | 
| 288 | 
            +
                    return self._own_vars
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                @property
         | 
| 291 | 
            +
                def vars(self):
         | 
| 292 | 
            +
                    """All variables (local_name => var)."""
         | 
| 293 | 
            +
                    return copy.copy(self._get_vars())
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def _get_vars(self):
         | 
| 296 | 
            +
                    if self._vars is None:
         | 
| 297 | 
            +
                        self._vars = OrderedDict(self._get_own_vars())
         | 
| 298 | 
            +
                        for comp in self._get_components().values():
         | 
| 299 | 
            +
                            self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items())
         | 
| 300 | 
            +
                    return self._vars
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                @property
         | 
| 303 | 
            +
                def trainables(self):
         | 
| 304 | 
            +
                    """All trainable variables (local_name => var)."""
         | 
| 305 | 
            +
                    return copy.copy(self._get_trainables())
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                def _get_trainables(self):
         | 
| 308 | 
            +
                    if self._trainables is None:
         | 
| 309 | 
            +
                        self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
         | 
| 310 | 
            +
                    return self._trainables
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                @property
         | 
| 313 | 
            +
                def var_global_to_local(self):
         | 
| 314 | 
            +
                    """Mapping from variable global names to local names."""
         | 
| 315 | 
            +
                    return copy.copy(self._get_var_global_to_local())
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                def _get_var_global_to_local(self):
         | 
| 318 | 
            +
                    if self._var_global_to_local is None:
         | 
| 319 | 
            +
                        self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
         | 
| 320 | 
            +
                    return self._var_global_to_local
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                def reset_own_vars(self) -> None:
         | 
| 323 | 
            +
                    """Re-initialize all variables of this network, excluding sub-networks."""
         | 
| 324 | 
            +
                    if self._var_inits is None or self._components is None:
         | 
| 325 | 
            +
                        tfutil.run([var.initializer for var in self._get_own_vars().values()])
         | 
| 326 | 
            +
                    else:
         | 
| 327 | 
            +
                        self._var_inits.clear()
         | 
| 328 | 
            +
                        self._all_inits_known = False
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                def reset_vars(self) -> None:
         | 
| 331 | 
            +
                    """Re-initialize all variables of this network, including sub-networks."""
         | 
| 332 | 
            +
                    if self._var_inits is None:
         | 
| 333 | 
            +
                        tfutil.run([var.initializer for var in self._get_vars().values()])
         | 
| 334 | 
            +
                    else:
         | 
| 335 | 
            +
                        self._var_inits.clear()
         | 
| 336 | 
            +
                        self._all_inits_known = False
         | 
| 337 | 
            +
                        if self._components is not None:
         | 
| 338 | 
            +
                            for comp in self._components.values():
         | 
| 339 | 
            +
                                comp.reset_vars()
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                def reset_trainables(self) -> None:
         | 
| 342 | 
            +
                    """Re-initialize all trainable variables of this network, including sub-networks."""
         | 
| 343 | 
            +
                    tfutil.run([var.initializer for var in self._get_trainables().values()])
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
         | 
| 346 | 
            +
                    """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).
         | 
| 347 | 
            +
                    The graph is placed on the current TensorFlow device."""
         | 
| 348 | 
            +
                    assert len(in_expr) == self.num_inputs
         | 
| 349 | 
            +
                    assert not all(expr is None for expr in in_expr)
         | 
| 350 | 
            +
                    self._get_vars()  # ensure that all variables have been created
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    # Choose build func kwargs.
         | 
| 353 | 
            +
                    build_kwargs = dict(self.static_kwargs)
         | 
| 354 | 
            +
                    build_kwargs.update(dynamic_kwargs)
         | 
| 355 | 
            +
                    build_kwargs["is_template_graph"] = False
         | 
| 356 | 
            +
                    build_kwargs["components"] = self._components
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    # Build TensorFlow graph to evaluate the network.
         | 
| 359 | 
            +
                    with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
         | 
| 360 | 
            +
                        assert tf.get_variable_scope().name == self.scope
         | 
| 361 | 
            +
                        valid_inputs = [expr for expr in in_expr if expr is not None]
         | 
| 362 | 
            +
                        final_inputs = []
         | 
| 363 | 
            +
                        for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
         | 
| 364 | 
            +
                            if expr is not None:
         | 
| 365 | 
            +
                                expr = tf.identity(expr, name=name)
         | 
| 366 | 
            +
                            else:
         | 
| 367 | 
            +
                                expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
         | 
| 368 | 
            +
                            final_inputs.append(expr)
         | 
| 369 | 
            +
                        out_expr = self._build_func(*final_inputs, **build_kwargs)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    # Propagate input shapes back to the user-specified expressions.
         | 
| 372 | 
            +
                    for expr, final in zip(in_expr, final_inputs):
         | 
| 373 | 
            +
                        if isinstance(expr, tf.Tensor):
         | 
| 374 | 
            +
                            expr.set_shape(final.shape)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    # Express outputs in the desired format.
         | 
| 377 | 
            +
                    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
         | 
| 378 | 
            +
                    if return_as_list:
         | 
| 379 | 
            +
                        out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
         | 
| 380 | 
            +
                    return out_expr
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
         | 
| 383 | 
            +
                    """Get the local name of a given variable, without any surrounding name scopes."""
         | 
| 384 | 
            +
                    assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
         | 
| 385 | 
            +
                    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
         | 
| 386 | 
            +
                    return self._get_var_global_to_local()[global_name]
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
         | 
| 389 | 
            +
                    """Find variable by local or global name."""
         | 
| 390 | 
            +
                    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
         | 
| 391 | 
            +
                    return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
         | 
| 394 | 
            +
                    """Get the value of a given variable as NumPy array.
         | 
| 395 | 
            +
                    Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
         | 
| 396 | 
            +
                    return self.find_var(var_or_local_name).eval()
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
         | 
| 399 | 
            +
                    """Set the value of a given variable based on the given NumPy array.
         | 
| 400 | 
            +
                    Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
         | 
| 401 | 
            +
                    tfutil.set_vars({self.find_var(var_or_local_name): new_value})
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                def __getstate__(self) -> dict:
         | 
| 404 | 
            +
                    """Pickle export."""
         | 
| 405 | 
            +
                    state = dict()
         | 
| 406 | 
            +
                    state["version"]            = 5
         | 
| 407 | 
            +
                    state["name"]               = self.name
         | 
| 408 | 
            +
                    state["static_kwargs"]      = dict(self.static_kwargs)
         | 
| 409 | 
            +
                    state["components"]         = dict(self.components)
         | 
| 410 | 
            +
                    state["build_module_src"]   = self._build_module_src
         | 
| 411 | 
            +
                    state["build_func_name"]    = self._build_func_name
         | 
| 412 | 
            +
                    state["variables"]          = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values()))))
         | 
| 413 | 
            +
                    state["input_shapes"]       = self.input_shapes
         | 
| 414 | 
            +
                    state["output_shapes"]      = self.output_shapes
         | 
| 415 | 
            +
                    state["input_names"]        = self.input_names
         | 
| 416 | 
            +
                    state["output_names"]       = self.output_names
         | 
| 417 | 
            +
                    return state
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                def __setstate__(self, state: dict) -> None:
         | 
| 420 | 
            +
                    """Pickle import."""
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    # Execute custom import handlers.
         | 
| 423 | 
            +
                    for handler in _import_handlers:
         | 
| 424 | 
            +
                        state = handler(state)
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                    # Get basic fields.
         | 
| 427 | 
            +
                    assert state["version"] in [2, 3, 4, 5]
         | 
| 428 | 
            +
                    name = state["name"]
         | 
| 429 | 
            +
                    static_kwargs = state["static_kwargs"]
         | 
| 430 | 
            +
                    build_module_src = state["build_module_src"]
         | 
| 431 | 
            +
                    build_func_name = state["build_func_name"]
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    # Create temporary module from the imported source code.
         | 
| 434 | 
            +
                    module_name = "_tflib_network_import_" + uuid.uuid4().hex
         | 
| 435 | 
            +
                    module = types.ModuleType(module_name)
         | 
| 436 | 
            +
                    sys.modules[module_name] = module
         | 
| 437 | 
            +
                    _import_module_src[module] = build_module_src
         | 
| 438 | 
            +
                    exec(build_module_src, module.__dict__) # pylint: disable=exec-used
         | 
| 439 | 
            +
                    build_func = util.get_obj_from_module(module, build_func_name)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    # Initialize fields.
         | 
| 442 | 
            +
                    self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src)
         | 
| 443 | 
            +
                    self._var_inits.update(copy.deepcopy(state["variables"]))
         | 
| 444 | 
            +
                    self._all_inits_known   = True
         | 
| 445 | 
            +
                    self._components        = util.EasyDict(state.get("components", {}))
         | 
| 446 | 
            +
                    self._input_shapes      = copy.deepcopy(state.get("input_shapes", None))
         | 
| 447 | 
            +
                    self._output_shapes     = copy.deepcopy(state.get("output_shapes", None))
         | 
| 448 | 
            +
                    self._input_names       = copy.deepcopy(state.get("input_names", None))
         | 
| 449 | 
            +
                    self._output_names      = copy.deepcopy(state.get("output_names", None))
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                def clone(self, name: str = None, **new_static_kwargs) -> "Network":
         | 
| 452 | 
            +
                    """Create a clone of this network with its own copy of the variables."""
         | 
| 453 | 
            +
                    static_kwargs = dict(self.static_kwargs)
         | 
| 454 | 
            +
                    static_kwargs.update(new_static_kwargs)
         | 
| 455 | 
            +
                    net = object.__new__(Network)
         | 
| 456 | 
            +
                    net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src)
         | 
| 457 | 
            +
                    net.copy_vars_from(self)
         | 
| 458 | 
            +
                    return net
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                def copy_own_vars_from(self, src_net: "Network") -> None:
         | 
| 461 | 
            +
                    """Copy the values of all variables from the given network, excluding sub-networks."""
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                    # Source has unknown variables or unknown components => init now.
         | 
| 464 | 
            +
                    if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
         | 
| 465 | 
            +
                        src_net._get_vars()
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                   # Both networks are inited => copy directly.
         | 
| 468 | 
            +
                    if src_net._var_inits is None and self._var_inits is None:
         | 
| 469 | 
            +
                        names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()]
         | 
| 470 | 
            +
                        tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
         | 
| 471 | 
            +
                        return
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                    # Read from source.
         | 
| 474 | 
            +
                    if src_net._var_inits is None:
         | 
| 475 | 
            +
                        value_dict = tfutil.run(src_net._get_own_vars())
         | 
| 476 | 
            +
                    else:
         | 
| 477 | 
            +
                        value_dict = src_net._var_inits
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    # Write to destination.
         | 
| 480 | 
            +
                    if self._var_inits is None:
         | 
| 481 | 
            +
                        tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()})
         | 
| 482 | 
            +
                    else:
         | 
| 483 | 
            +
                        self._var_inits.update(value_dict)
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                def copy_vars_from(self, src_net: "Network") -> None:
         | 
| 486 | 
            +
                    """Copy the values of all variables from the given network, including sub-networks."""
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                    # Source has unknown variables or unknown components => init now.
         | 
| 489 | 
            +
                    if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
         | 
| 490 | 
            +
                        src_net._get_vars()
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                    # Source is inited, but destination components have not been created yet => set as initial values.
         | 
| 493 | 
            +
                    if src_net._var_inits is None and self._components is None:
         | 
| 494 | 
            +
                        self._var_inits.update(tfutil.run(src_net._get_vars()))
         | 
| 495 | 
            +
                        return
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    # Destination has unknown components => init now.
         | 
| 498 | 
            +
                    if self._components is None:
         | 
| 499 | 
            +
                        self._get_vars()
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                    # Both networks are inited => copy directly.
         | 
| 502 | 
            +
                    if src_net._var_inits is None and self._var_inits is None:
         | 
| 503 | 
            +
                        names = [name for name in self._get_vars().keys() if name in src_net._get_vars()]
         | 
| 504 | 
            +
                        tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
         | 
| 505 | 
            +
                        return
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                    # Copy recursively, component by component.
         | 
| 508 | 
            +
                    self.copy_own_vars_from(src_net)
         | 
| 509 | 
            +
                    for name, src_comp in src_net._components.items():
         | 
| 510 | 
            +
                        if name in self._components:
         | 
| 511 | 
            +
                            self._components[name].copy_vars_from(src_comp)
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                def copy_trainables_from(self, src_net: "Network") -> None:
         | 
| 514 | 
            +
                    """Copy the values of all trainable variables from the given network, including sub-networks."""
         | 
| 515 | 
            +
                    names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
         | 
| 516 | 
            +
                    tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
         | 
| 519 | 
            +
                    """Create new network with the given parameters, and copy all variables from this network."""
         | 
| 520 | 
            +
                    if new_name is None:
         | 
| 521 | 
            +
                        new_name = self.name
         | 
| 522 | 
            +
                    static_kwargs = dict(self.static_kwargs)
         | 
| 523 | 
            +
                    static_kwargs.update(new_static_kwargs)
         | 
| 524 | 
            +
                    net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
         | 
| 525 | 
            +
                    net.copy_vars_from(self)
         | 
| 526 | 
            +
                    return net
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
         | 
| 529 | 
            +
                    """Construct a TensorFlow op that updates the variables of this network
         | 
| 530 | 
            +
                    to be slightly closer to those of the given network."""
         | 
| 531 | 
            +
                    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
         | 
| 532 | 
            +
                        ops = []
         | 
| 533 | 
            +
                        for name, var in self._get_vars().items():
         | 
| 534 | 
            +
                            if name in src_net._get_vars():
         | 
| 535 | 
            +
                                cur_beta = beta if var.trainable else beta_nontrainable
         | 
| 536 | 
            +
                                new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
         | 
| 537 | 
            +
                                ops.append(var.assign(new_value))
         | 
| 538 | 
            +
                        return tf.group(*ops)
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                def run(self,
         | 
| 541 | 
            +
                        *in_arrays: Tuple[Union[np.ndarray, None], ...],
         | 
| 542 | 
            +
                        input_transform: dict = None,
         | 
| 543 | 
            +
                        output_transform: dict = None,
         | 
| 544 | 
            +
                        return_as_list: bool = False,
         | 
| 545 | 
            +
                        print_progress: bool = False,
         | 
| 546 | 
            +
                        minibatch_size: int = None,
         | 
| 547 | 
            +
                        num_gpus: int = 1,
         | 
| 548 | 
            +
                        assume_frozen: bool = False,
         | 
| 549 | 
            +
                        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
         | 
| 550 | 
            +
                    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    Args:
         | 
| 553 | 
            +
                        input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
         | 
| 554 | 
            +
                                            The dict must contain a 'func' field that points to a top-level function. The function is called with the input
         | 
| 555 | 
            +
                                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
         | 
| 556 | 
            +
                        output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
         | 
| 557 | 
            +
                                            The dict must contain a 'func' field that points to a top-level function. The function is called with the output
         | 
| 558 | 
            +
                                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
         | 
| 559 | 
            +
                        return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
         | 
| 560 | 
            +
                        print_progress:     Print progress to the console? Useful for very large input arrays.
         | 
| 561 | 
            +
                        minibatch_size:     Maximum minibatch size to use, None = disable batching.
         | 
| 562 | 
            +
                        num_gpus:           Number of GPUs to use.
         | 
| 563 | 
            +
                        assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
         | 
| 564 | 
            +
                        dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
         | 
| 565 | 
            +
                    """
         | 
| 566 | 
            +
                    assert len(in_arrays) == self.num_inputs
         | 
| 567 | 
            +
                    assert not all(arr is None for arr in in_arrays)
         | 
| 568 | 
            +
                    assert input_transform is None or util.is_top_level_function(input_transform["func"])
         | 
| 569 | 
            +
                    assert output_transform is None or util.is_top_level_function(output_transform["func"])
         | 
| 570 | 
            +
                    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
         | 
| 571 | 
            +
                    num_items = in_arrays[0].shape[0]
         | 
| 572 | 
            +
                    if minibatch_size is None:
         | 
| 573 | 
            +
                        minibatch_size = num_items
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                    # Construct unique hash key from all arguments that affect the TensorFlow graph.
         | 
| 576 | 
            +
                    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
         | 
| 577 | 
            +
                    def unwind_key(obj):
         | 
| 578 | 
            +
                        if isinstance(obj, dict):
         | 
| 579 | 
            +
                            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
         | 
| 580 | 
            +
                        if callable(obj):
         | 
| 581 | 
            +
                            return util.get_top_level_function_name(obj)
         | 
| 582 | 
            +
                        return obj
         | 
| 583 | 
            +
                    key = repr(unwind_key(key))
         | 
| 584 | 
            +
             | 
| 585 | 
            +
                    # Build graph.
         | 
| 586 | 
            +
                    if key not in self._run_cache:
         | 
| 587 | 
            +
                        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
         | 
| 588 | 
            +
                            with tf.device("/cpu:0"):
         | 
| 589 | 
            +
                                in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
         | 
| 590 | 
            +
                                in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                            out_split = []
         | 
| 593 | 
            +
                            for gpu in range(num_gpus):
         | 
| 594 | 
            +
                                with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu):
         | 
| 595 | 
            +
                                    net_gpu = self.clone() if assume_frozen else self
         | 
| 596 | 
            +
                                    in_gpu = in_split[gpu]
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                                    if input_transform is not None:
         | 
| 599 | 
            +
                                        in_kwargs = dict(input_transform)
         | 
| 600 | 
            +
                                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
         | 
| 601 | 
            +
                                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                                    assert len(in_gpu) == self.num_inputs
         | 
| 604 | 
            +
                                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
         | 
| 605 | 
            +
             | 
| 606 | 
            +
                                    if output_transform is not None:
         | 
| 607 | 
            +
                                        out_kwargs = dict(output_transform)
         | 
| 608 | 
            +
                                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
         | 
| 609 | 
            +
                                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
         | 
| 610 | 
            +
             | 
| 611 | 
            +
                                    assert len(out_gpu) == self.num_outputs
         | 
| 612 | 
            +
                                    out_split.append(out_gpu)
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                            with tf.device("/cpu:0"):
         | 
| 615 | 
            +
                                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
         | 
| 616 | 
            +
                                self._run_cache[key] = in_expr, out_expr
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                    # Run minibatches.
         | 
| 619 | 
            +
                    in_expr, out_expr = self._run_cache[key]
         | 
| 620 | 
            +
                    out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                    for mb_begin in range(0, num_items, minibatch_size):
         | 
| 623 | 
            +
                        if print_progress:
         | 
| 624 | 
            +
                            print("\r%d / %d" % (mb_begin, num_items), end="")
         | 
| 625 | 
            +
             | 
| 626 | 
            +
                        mb_end = min(mb_begin + minibatch_size, num_items)
         | 
| 627 | 
            +
                        mb_num = mb_end - mb_begin
         | 
| 628 | 
            +
                        mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
         | 
| 629 | 
            +
                        mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                        for dst, src in zip(out_arrays, mb_out):
         | 
| 632 | 
            +
                            dst[mb_begin: mb_end] = src
         | 
| 633 | 
            +
             | 
| 634 | 
            +
                    # Done.
         | 
| 635 | 
            +
                    if print_progress:
         | 
| 636 | 
            +
                        print("\r%d / %d" % (num_items, num_items))
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                    if not return_as_list:
         | 
| 639 | 
            +
                        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
         | 
| 640 | 
            +
                    return out_arrays
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                def list_ops(self) -> List[TfExpression]:
         | 
| 643 | 
            +
                    _ = self.output_templates  # ensure that the template graph has been created
         | 
| 644 | 
            +
                    include_prefix = self.scope + "/"
         | 
| 645 | 
            +
                    exclude_prefix = include_prefix + "_"
         | 
| 646 | 
            +
                    ops = tf.get_default_graph().get_operations()
         | 
| 647 | 
            +
                    ops = [op for op in ops if op.name.startswith(include_prefix)]
         | 
| 648 | 
            +
                    ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
         | 
| 649 | 
            +
                    return ops
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
         | 
| 652 | 
            +
                    """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
         | 
| 653 | 
            +
                    individual layers of the network. Mainly intended to be used for reporting."""
         | 
| 654 | 
            +
                    layers = []
         | 
| 655 | 
            +
             | 
| 656 | 
            +
                    def recurse(scope, parent_ops, parent_vars, level):
         | 
| 657 | 
            +
                        if len(parent_ops) == 0 and len(parent_vars) == 0:
         | 
| 658 | 
            +
                            return
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                        # Ignore specific patterns.
         | 
| 661 | 
            +
                        if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
         | 
| 662 | 
            +
                            return
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                        # Filter ops and vars by scope.
         | 
| 665 | 
            +
                        global_prefix = scope + "/"
         | 
| 666 | 
            +
                        local_prefix = global_prefix[len(self.scope) + 1:]
         | 
| 667 | 
            +
                        cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
         | 
| 668 | 
            +
                        cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
         | 
| 669 | 
            +
                        if not cur_ops and not cur_vars:
         | 
| 670 | 
            +
                            return
         | 
| 671 | 
            +
             | 
| 672 | 
            +
                        # Filter out all ops related to variables.
         | 
| 673 | 
            +
                        for var in [op for op in cur_ops if op.type.startswith("Variable")]:
         | 
| 674 | 
            +
                            var_prefix = var.name + "/"
         | 
| 675 | 
            +
                            cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
         | 
| 676 | 
            +
             | 
| 677 | 
            +
                        # Scope does not contain ops as immediate children => recurse deeper.
         | 
| 678 | 
            +
                        contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
         | 
| 679 | 
            +
                        if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0):
         | 
| 680 | 
            +
                            visited = set()
         | 
| 681 | 
            +
                            for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
         | 
| 682 | 
            +
                                token = rel_name.split("/")[0]
         | 
| 683 | 
            +
                                if token not in visited:
         | 
| 684 | 
            +
                                    recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
         | 
| 685 | 
            +
                                    visited.add(token)
         | 
| 686 | 
            +
                            return
         | 
| 687 | 
            +
             | 
| 688 | 
            +
                        # Report layer.
         | 
| 689 | 
            +
                        layer_name = scope[len(self.scope) + 1:]
         | 
| 690 | 
            +
                        layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
         | 
| 691 | 
            +
                        layer_trainables = [var for _name, var in cur_vars if var.trainable]
         | 
| 692 | 
            +
                        layers.append((layer_name, layer_output, layer_trainables))
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                    recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0)
         | 
| 695 | 
            +
                    return layers
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
         | 
| 698 | 
            +
                    """Print a summary table of the network structure."""
         | 
| 699 | 
            +
                    rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
         | 
| 700 | 
            +
                    rows += [["---"] * 4]
         | 
| 701 | 
            +
                    total_params = 0
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                    for layer_name, layer_output, layer_trainables in self.list_layers():
         | 
| 704 | 
            +
                        num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
         | 
| 705 | 
            +
                        weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
         | 
| 706 | 
            +
                        weights.sort(key=lambda x: len(x.name))
         | 
| 707 | 
            +
                        if len(weights) == 0 and len(layer_trainables) == 1:
         | 
| 708 | 
            +
                            weights = layer_trainables
         | 
| 709 | 
            +
                        total_params += num_params
         | 
| 710 | 
            +
             | 
| 711 | 
            +
                        if not hide_layers_with_no_params or num_params != 0:
         | 
| 712 | 
            +
                            num_params_str = str(num_params) if num_params > 0 else "-"
         | 
| 713 | 
            +
                            output_shape_str = str(layer_output.shape)
         | 
| 714 | 
            +
                            weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
         | 
| 715 | 
            +
                            rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
         | 
| 716 | 
            +
             | 
| 717 | 
            +
                    rows += [["---"] * 4]
         | 
| 718 | 
            +
                    rows += [["Total", str(total_params), "", ""]]
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
         | 
| 721 | 
            +
                    print()
         | 
| 722 | 
            +
                    for row in rows:
         | 
| 723 | 
            +
                        print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
         | 
| 724 | 
            +
                    print()
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                def setup_weight_histograms(self, title: str = None) -> None:
         | 
| 727 | 
            +
                    """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
         | 
| 728 | 
            +
                    if title is None:
         | 
| 729 | 
            +
                        title = self.name
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
         | 
| 732 | 
            +
                        for local_name, var in self._get_trainables().items():
         | 
| 733 | 
            +
                            if "/" in local_name:
         | 
| 734 | 
            +
                                p = local_name.split("/")
         | 
| 735 | 
            +
                                name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
         | 
| 736 | 
            +
                            else:
         | 
| 737 | 
            +
                                name = title + "_toplevel/" + local_name
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                            tf.summary.histogram(name, var)
         | 
| 740 | 
            +
             | 
| 741 | 
            +
            #----------------------------------------------------------------------------
         | 
| 742 | 
            +
            # Backwards-compatible emulation of legacy output transformation in Network.run().
         | 
| 743 | 
            +
             | 
| 744 | 
            +
            _print_legacy_warning = True
         | 
| 745 | 
            +
             | 
| 746 | 
            +
            def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
         | 
| 747 | 
            +
                global _print_legacy_warning
         | 
| 748 | 
            +
                legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
         | 
| 749 | 
            +
                if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
         | 
| 750 | 
            +
                    return output_transform, dynamic_kwargs
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                if _print_legacy_warning:
         | 
| 753 | 
            +
                    _print_legacy_warning = False
         | 
| 754 | 
            +
                    print()
         | 
| 755 | 
            +
                    print("WARNING: Old-style output transformations in Network.run() are deprecated.")
         | 
| 756 | 
            +
                    print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
         | 
| 757 | 
            +
                    print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
         | 
| 758 | 
            +
                    print()
         | 
| 759 | 
            +
                assert output_transform is None
         | 
| 760 | 
            +
             | 
| 761 | 
            +
                new_kwargs = dict(dynamic_kwargs)
         | 
| 762 | 
            +
                new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
         | 
| 763 | 
            +
                new_transform["func"] = _legacy_output_transform_func
         | 
| 764 | 
            +
                return new_transform, new_kwargs
         | 
| 765 | 
            +
             | 
| 766 | 
            +
            def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
         | 
| 767 | 
            +
                if out_mul != 1.0:
         | 
| 768 | 
            +
                    expr = [x * out_mul for x in expr]
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                if out_add != 0.0:
         | 
| 771 | 
            +
                    expr = [x + out_add for x in expr]
         | 
| 772 | 
            +
             | 
| 773 | 
            +
                if out_shrink > 1:
         | 
| 774 | 
            +
                    ksize = [1, 1, out_shrink, out_shrink]
         | 
| 775 | 
            +
                    expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
         | 
| 776 | 
            +
             | 
| 777 | 
            +
                if out_dtype is not None:
         | 
| 778 | 
            +
                    if tf.as_dtype(out_dtype).is_integer:
         | 
| 779 | 
            +
                        expr = [tf.round(x) for x in expr]
         | 
| 780 | 
            +
                    expr = [tf.saturate_cast(x, out_dtype) for x in expr]
         | 
| 781 | 
            +
                return expr
         | 
    	
        PTI/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 4 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 5 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 6 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 7 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # empty
         | 

