Commit b6dd358 by max · 0 parent(s)
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +32 -0
  2. LICENSE +162 -0
  3. README.md +13 -0
  4. app.py +327 -0
  5. dataset_tool.py +444 -0
  6. datasets/dataset_256.py +286 -0
  7. datasets/dataset_256_val.py +282 -0
  8. datasets/dataset_512.py +286 -0
  9. datasets/dataset_512_val.py +282 -0
  10. datasets/mask_generator_256.py +93 -0
  11. datasets/mask_generator_256_small.py +93 -0
  12. datasets/mask_generator_512.py +93 -0
  13. datasets/mask_generator_512_small.py +93 -0
  14. dnnlib/__init__.py +9 -0
  15. dnnlib/util.py +477 -0
  16. evaluatoin/cal_fid_pids_uids.py +193 -0
  17. evaluatoin/cal_lpips.py +71 -0
  18. evaluatoin/cal_psnr_ssim_l1.py +107 -0
  19. legacy.py +323 -0
  20. losses/loss.py +170 -0
  21. losses/pcp.py +126 -0
  22. losses/vggNet.py +178 -0
  23. metrics/__init__.py +9 -0
  24. metrics/frechet_inception_distance.py +41 -0
  25. metrics/inception_discriminative_score.py +37 -0
  26. metrics/inception_score.py +38 -0
  27. metrics/kernel_inception_distance.py +46 -0
  28. metrics/metric_main.py +184 -0
  29. metrics/metric_utils.py +434 -0
  30. metrics/perceptual_path_length.py +131 -0
  31. metrics/precision_recall.py +62 -0
  32. metrics/psnr_ssim_l1.py +19 -0
  33. models/Places_512_FullData+LAION300k.pkl +3 -0
  34. models/Places_512_FullData.pkl +3 -0
  35. networks/basic_module.py +583 -0
  36. networks/mat.py +996 -0
  37. op.gif +3 -0
  38. requirements.txt +16 -0
  39. test_sets/CelebA-HQ/images/test1.png +0 -0
  40. test_sets/CelebA-HQ/images/test2.png +0 -0
  41. test_sets/CelebA-HQ/masks/mask1.png +0 -0
  42. test_sets/CelebA-HQ/masks/mask2.png +0 -0
  43. test_sets/Places/images/test1.jpg +0 -0
  44. test_sets/Places/images/test2.jpg +0 -0
  45. test_sets/Places/masks/mask1.png +0 -0
  46. test_sets/Places/masks/mask2.png +0 -0
  47. torch_utils/__init__.py +9 -0
  48. torch_utils/custom_ops.py +126 -0
  49. torch_utils/misc.py +268 -0
  50. torch_utils/ops/__init__.py +9 -0
.gitattributes ADDED
@@ -0,0 +1,32 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.gif filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pickle filter=lfs diff=lfs merge=lfs -text
19
+ *.pkl filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
26
+ *.tflite filter=lfs diff=lfs merge=lfs -text
27
+ *.tgz filter=lfs diff=lfs merge=lfs -text
28
+ *.wasm filter=lfs diff=lfs merge=lfs -text
29
+ *.xz filter=lfs diff=lfs merge=lfs -text
30
+ *.zip filter=lfs diff=lfs merge=lfs -text
31
+ *.zst filter=lfs diff=lfs merge=lfs -text
32
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,162 @@
1
+ ## creative commons
2
+
3
+ # Attribution-NonCommercial 4.0 International
4
+
5
+ Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6
+
7
+ ### Using Creative Commons Public Licenses
8
+
9
+ Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10
+
11
+ * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12
+
13
+ * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14
+
15
+ ## Creative Commons Attribution-NonCommercial 4.0 International Public License
16
+
17
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18
+
19
+ ### Section 1 – Definitions.
20
+
21
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22
+
23
+ b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24
+
25
+ c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26
+
27
+ d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28
+
29
+ e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30
+
31
+ f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
32
+
33
+ g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
34
+
35
+ h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
36
+
37
+ i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
38
+
39
+ j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
40
+
41
+ k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
42
+
43
+ l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
44
+
45
+ ### Section 2 – Scope.
46
+
47
+ a. ___License grant.___
48
+
49
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
50
+
51
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
52
+
53
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
54
+
55
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
56
+
57
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
58
+
59
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
60
+
61
+ 5. __Downstream recipients.__
62
+
63
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
64
+
65
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
66
+
67
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
68
+
69
+ b. ___Other rights.___
70
+
71
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
72
+
73
+ 2. Patent and trademark rights are not licensed under this Public License.
74
+
75
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
76
+
77
+ ### Section 3 – License Conditions.
78
+
79
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
80
+
81
+ a. ___Attribution.___
82
+
83
+ 1. If You Share the Licensed Material (including in modified form), You must:
84
+
85
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
86
+
87
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
88
+
89
+ ii. a copyright notice;
90
+
91
+ iii. a notice that refers to this Public License;
92
+
93
+ iv. a notice that refers to the disclaimer of warranties;
94
+
95
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
96
+
97
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
98
+
99
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
100
+
101
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
102
+
103
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
104
+
105
+ 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
106
+
107
+ ### Section 4 – Sui Generis Database Rights.
108
+
109
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
110
+
111
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
112
+
113
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
114
+
115
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
116
+
117
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
118
+
119
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
120
+
121
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
122
+
123
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
124
+
125
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
126
+
127
+ ### Section 6 – Term and Termination.
128
+
129
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
130
+
131
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
132
+
133
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
134
+
135
+ 2. upon express reinstatement by the Licensor.
136
+
137
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
138
+
139
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
140
+
141
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
142
+
143
+ ### Section 7 – Other Terms and Conditions.
144
+
145
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
146
+
147
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
148
+
149
+ ### Section 8 – Interpretation.
150
+
151
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
152
+
153
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
154
+
155
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
156
+
157
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
158
+
159
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
160
+ >
161
+ > Creative Commons may be contacted at creativecommons.org
162
+
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: Stable Diffusion Mat Outpainting Primer
3
+ emoji: 🐢
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.4
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-4.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,327 @@
1
+ # %%
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ from networks.mat import Generator
12
+ import gradio as gr
13
+ import gradio.components as gc
14
+ import base64
15
+ import glob
16
+ import os
17
+ import random
18
+ import re
19
+ from http import HTTPStatus
20
+ from io import BytesIO
21
+ from typing import Dict, List, NamedTuple, Optional, Tuple
22
+
23
+ import click
24
+ import cv2
25
+ import numpy as np
26
+ import PIL.Image
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from PIL import Image, ImageDraw, ImageOps
30
+ from pydantic import BaseModel
31
+
32
+ import dnnlib
33
+ import legacy
34
+
35
+
36
+ pyspng = None
37
+
38
+
39
+ def num_range(s: str) -> List[int]:
40
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
41
+
42
+ range_re = re.compile(r'^(\d+)-(\d+)$')
43
+ m = range_re.match(s)
44
+ if m:
45
+ return list(range(int(m.group(1)), int(m.group(2))+1))
46
+ vals = s.split(',')
47
+ return [int(x) for x in vals]
48
+
49
+
50
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
51
+ assert isinstance(src_module, torch.nn.Module)
52
+ assert isinstance(dst_module, torch.nn.Module)
53
+ src_tensors = {name: tensor for name,
54
+ tensor in named_params_and_buffers(src_module)}
55
+ for name, tensor in named_params_and_buffers(dst_module):
56
+ assert (name in src_tensors) or (not require_all)
57
+ if name in src_tensors:
58
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(
59
+ tensor.requires_grad)
60
+
61
+
62
+ def params_and_buffers(module):
63
+ assert isinstance(module, torch.nn.Module)
64
+ return list(module.parameters()) + list(module.buffers())
65
+
66
+
67
+ def named_params_and_buffers(module):
68
+ assert isinstance(module, torch.nn.Module)
69
+ return list(module.named_parameters()) + list(module.named_buffers())
70
+
71
+
72
+ class Inpainter:
73
+ def __init__(self,
74
+ network_pkl,
75
+ resolution=512,
76
+ truncation_psi=1,
77
+ noise_mode='const',
78
+ sdevice='cpu'
79
+ ):
80
+ self.resolution = resolution
81
+ self.truncation_psi = truncation_psi
82
+ self.noise_mode = noise_mode
83
+ print(f'Loading networks from: {network_pkl}')
84
+ self.device = torch.device(sdevice)
85
+ with dnnlib.util.open_url(network_pkl) as f:
86
+ G_saved = (
87
+ legacy.load_network_pkl(f)
88
+ ['G_ema']
89
+ .to(self.device)
90
+ .eval()
91
+ .requires_grad_(False)) # type: ignore
92
+ net_res = 512 if resolution > 512 else resolution
93
+ self.G = (
94
+ Generator(
95
+ z_dim=512,
96
+ c_dim=0,
97
+ w_dim=512,
98
+ img_resolution=net_res,
99
+ img_channels=3
100
+ )
101
+ .to(self.device)
102
+ .eval()
103
+ .requires_grad_(False)
104
+ )
105
+ copy_params_and_buffers(G_saved, self.G, require_all=True)
106
+
107
+ def generate_images2(
108
+ self,
109
+ dpath: List[PIL.Image.Image],
110
+ mpath: List[Optional[PIL.Image.Image]],
111
+ seed: int = 42,
112
+ ):
113
+ """
114
+ Generate images using pretrained network pickle.
115
+ """
116
+ resolution = self.resolution
117
+ truncation_psi = self.truncation_psi
118
+ noise_mode = self.noise_mode
119
+ # seed = 240 # pick up a random number
120
+
121
+ def seed_all(seed):
122
+ random.seed(seed)
123
+ np.random.seed(seed)
124
+ torch.manual_seed(seed)
125
+ torch.cuda.manual_seed(seed)
126
+ if seed is not None:
127
+ seed_all(seed)
128
+
129
+ # no Labels.
130
+ label = torch.zeros([1, self.G.c_dim], device=self.device)
131
+
132
+ def read_image(image):
133
+ image = np.array(image)
134
+ if image.ndim == 2:
135
+ image = image[:, :, np.newaxis] # HW => HWC
136
+ image = np.repeat(image, 3, axis=2)
137
+ image = image.transpose(2, 0, 1) # HWC => CHW
138
+ image = image[:3]
139
+ return image
140
+ if resolution != 512:
141
+ noise_mode = 'random'
142
+ results = []
143
+ with torch.no_grad():
144
+ for i, (ipath, m) in enumerate(zip(dpath, mpath)):
145
+ if seed is None:
146
+ seed_all(i)
147
+
148
+ image = read_image(ipath)
149
+ image = (torch.from_numpy(image).float().to(
150
+ self. device) / 127.5 - 1).unsqueeze(0)
151
+
152
+ mask = np.array(m).astype(np.float32) / 255.0
153
+ mask = torch.from_numpy(mask).float().to(
154
+ self. device).unsqueeze(0).unsqueeze(0)
155
+
156
+ z = torch.from_numpy(np.random.randn(
157
+ 1, self.G.z_dim)).to(self.device)
158
+ output = self.G(image, mask, z, label,
159
+ truncation_psi=truncation_psi, noise_mode=noise_mode)
160
+ output = (output.permute(0, 2, 3, 1) * 127.5 +
161
+ 127.5).round().clamp(0, 255).to(torch.uint8)
162
+ output = output[0].cpu().numpy()
163
+ results.append(PIL.Image.fromarray(output, 'RGB'))
164
+
165
+ return results
166
+
167
+
168
+ # if __name__ == "__main__":
169
+ # generate_images() # pylint: disable=no-value-for-parameter
170
+
171
+ # ----------------------------------------------------------------------------
172
+ def mask_to_alpha(img, mask):
173
+ img = img.copy()
174
+ img.putalpha(mask)
175
+ return img
176
+
177
+
178
+ def blend(src, target, mask):
179
+ mask = np.expand_dims(mask, axis=-1)
180
+ result = (1-mask) * src + mask * target
181
+ return Image.fromarray(result.astype(np.uint8))
182
+
183
+
184
+ def pad(img, size=(128, 128), tosize=(512, 512), border=1):
185
+ if isinstance(size, float):
186
+ size = (int(img.size[0] * size), int(img.size[1] * size))
187
+ # remove border
188
+ w, h = tosize
189
+
190
+ new_img = Image.new('RGBA', (w, h))
191
+
192
+ rimg = img.resize(size, resample=Image.Resampling.NEAREST)
193
+ rimg = ImageOps.crop(rimg, border=border)
194
+ tw, th = size
195
+ tw, th = tw - border*2, th - border*2
196
+ tc = ((w-tw)//2, (h-th)//2)
197
+
198
+ new_img.paste(rimg, tc)
199
+ mask = Image.new('L', (w, h))
200
+ white = Image.new('L', (tw, th), 255)
201
+ mask.paste(white, tc)
202
+
203
+ if 'A' in rimg.getbands():
204
+ mask.paste(img.getchannel('A'), tc)
205
+ return new_img, mask
206
+
207
+
208
+ def b64_to_img(b64):
209
+ return Image.open(BytesIO(base64.b64decode(b64)))
210
+
211
+
212
+ def img_to_b64(img):
213
+ with BytesIO() as f:
214
+ img.save(f, format='PNG')
215
+ return base64.b64encode(f.getvalue()).decode('utf-8')
216
+
217
+
218
+ class Predictor:
219
+ def __init__(self):
220
+ """Load the model into memory to make running multiple predictions efficient"""
221
+ self.models = {
222
+ "places2": Inpainter(
223
+ network_pkl='models/Places_512_FullData.pkl',
224
+ resolution=512,
225
+ truncation_psi=1.,
226
+ noise_mode='const',
227
+ ),
228
+ "places2+laion300k": Inpainter(
229
+ network_pkl='models/Places_512_FullData+LAION300k.pkl',
230
+ resolution=512,
231
+ truncation_psi=1.,
232
+ noise_mode='const',
233
+ ),
234
+ }
235
+
236
+ # The arguments and types the model takes as input
237
+
238
+ def predict(
239
+ self,
240
+ img: Image.Image,
241
+ tosize=(512, 512),
242
+ border=5,
243
+ seed=42,
244
+ size=0.5,
245
+ model='places2',
246
+ ) -> Image:
247
+ i, m = pad(
248
+ img,
249
+ size=size, # (328, 328),
250
+ tosize=tosize,
251
+ border=border
252
+ )
253
+ """Run a single prediction on the model"""
254
+ imgs = self.models[model].generate_images2(
255
+ dpath=[i.resize((512, 512), resample=Image.Resampling.NEAREST)],
256
+ mpath=[m.resize((512, 512), resample=Image.Resampling.NEAREST)],
257
+ seed=seed,
258
+ )
259
+ img_op_raw = imgs[0].convert('RGBA')
260
+ img_op_raw = img_op_raw.resize(
261
+ tosize, resample=Image.Resampling.NEAREST)
262
+ inpainted = img_op_raw.copy()
263
+
264
+ # paste original image to remove inpainting/scaling artifacts
265
+ inpainted = blend(
266
+ i,
267
+ inpainted,
268
+ 1-(np.array(m) / 255)
269
+ )
270
+ minpainted = mask_to_alpha(inpainted, m)
271
+ return minpainted, inpainted, ImageOps.invert(m)
272
+
273
+
274
+ predictor = Predictor()
275
+
276
+ # %%
277
+
278
+
279
+ def _outpaint(img, tosize, border, seed, size, model):
280
+ img_op = predictor.predict(
281
+ img,
282
+ border=border,
283
+ seed=seed,
284
+ tosize=(tosize, tosize),
285
+ size=float(size),
286
+ model=model,
287
+ )
288
+ return img_op
289
+ # %%
290
+
291
+
292
+ searchimage = gc.Image(shape=(224, 224), label="image", type='pil')
293
+ to_size = gc.Slider(1, 1920, 512, step=1, label='output size')
294
+ border = gc.Slider(
295
+ 1, 50, 0, step=1, label='border to crop from the image before outpainting')
296
+ seed = gc.Slider(1, 65536, 10, step=1, label='seed')
297
+ size = gc.Slider(0, 1, .5, step=0.01,
298
+ label='scale of the image before outpainting')
299
+
300
+ out = gc.Image(label="primed image with alpha channel", type='pil')
301
+ outwithoutalpha = gc.Image(
302
+ label="primed image without alpha channel", type='pil')
303
+ mask = gc.Image(label="outpainting mask", type='pil')
304
+
305
+ model = gc.Dropdown(
306
+ choices=['places2', 'places2+laion300k'],
307
+ value='places2',
308
+ label='model',
309
+ )
310
+
311
+
312
+ maturl = 'https://github.com/fenglinglwb/MAT'
313
+ gr.Interface(
314
+ _outpaint,
315
+ [searchimage, to_size, border, seed, size, model],
316
+ [out, outwithoutalpha, mask],
317
+ title=f"MAT Primer for Stable Diffusion\n\nbased on MAT: Mask-Aware Transformer for Large Hole Image Inpainting\n\n{maturl}",
318
+ description=f"""<html>
319
+ create a primer for use in stable diffusion outpainting<br>
320
+ example with strength 0.5
321
+ <img src='file/op.gif' />
322
+ </html>""",
323
+ analytics_enabled=False,
324
+ allow_flagging='never',
325
+
326
+
327
+ ).launch()
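
For orientation, here is a minimal sketch of how the Predictor defined in app.py above could be driven outside of the Gradio interface. It only restates the predict() signature shown above; the input and output file names are hypothetical, the model pickles under models/ are assumed to be present, and note that importing app as written also constructs and launches the Gradio interface.

    # Hypothetical standalone use of Predictor from app.py (sketch only).
    from PIL import Image
    from app import Predictor  # note: importing app also builds/launches the Gradio UI as written

    predictor = Predictor()
    primed, primed_no_alpha, mask = predictor.predict(
        Image.open('input.png'),   # hypothetical example image
        tosize=(512, 512),         # output canvas size
        border=5,                  # pixels cropped from the input edge before priming
        seed=42,
        size=0.5,                  # scale of the input inside the canvas
        model='places2',           # or 'places2+laion300k'
    )
    primed.save('primer.png')      # RGBA primer for Stable Diffusion outpainting
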
dataset_tool.py ADDED
@@ -0,0 +1,444 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import functools
10
+ import io
11
+ import json
12
+ import os
13
+ import pickle
14
+ import sys
15
+ import tarfile
16
+ import gzip
17
+ import zipfile
18
+ from pathlib import Path
19
+ from typing import Callable, Optional, Tuple, Union
20
+
21
+ import click
22
+ import numpy as np
23
+ import PIL.Image
24
+ from tqdm import tqdm
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def error(msg):
29
+ print('Error: ' + msg)
30
+ sys.exit(1)
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def maybe_min(a: int, b: Optional[int]) -> int:
35
+ if b is not None:
36
+ return min(a, b)
37
+ return a
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ def file_ext(name: Union[str, Path]) -> str:
42
+ return str(name).split('.')[-1]
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ def is_image_ext(fname: Union[str, Path]) -> bool:
47
+ ext = file_ext(fname).lower()
48
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def open_image_folder(source_dir, *, max_images: Optional[int]):
53
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
54
+
55
+ # Load labels.
56
+ labels = {}
57
+ meta_fname = os.path.join(source_dir, 'dataset.json')
58
+ if os.path.isfile(meta_fname):
59
+ with open(meta_fname, 'r') as file:
60
+ labels = json.load(file)['labels']
61
+ if labels is not None:
62
+ labels = { x[0]: x[1] for x in labels }
63
+ else:
64
+ labels = {}
65
+
66
+ max_idx = maybe_min(len(input_images), max_images)
67
+
68
+ def iterate_images():
69
+ for idx, fname in enumerate(input_images):
70
+ arch_fname = os.path.relpath(fname, source_dir)
71
+ arch_fname = arch_fname.replace('\\', '/')
72
+ img = np.array(PIL.Image.open(fname))
73
+ yield dict(img=img, label=labels.get(arch_fname))
74
+ if idx >= max_idx-1:
75
+ break
76
+ return max_idx, iterate_images()
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
+ def open_image_zip(source, *, max_images: Optional[int]):
81
+ with zipfile.ZipFile(source, mode='r') as z:
82
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
83
+
84
+ # Load labels.
85
+ labels = {}
86
+ if 'dataset.json' in z.namelist():
87
+ with z.open('dataset.json', 'r') as file:
88
+ labels = json.load(file)['labels']
89
+ if labels is not None:
90
+ labels = { x[0]: x[1] for x in labels }
91
+ else:
92
+ labels = {}
93
+
94
+ max_idx = maybe_min(len(input_images), max_images)
95
+
96
+ def iterate_images():
97
+ with zipfile.ZipFile(source, mode='r') as z:
98
+ for idx, fname in enumerate(input_images):
99
+ with z.open(fname, 'r') as file:
100
+ img = PIL.Image.open(file) # type: ignore
101
+ img = np.array(img)
102
+ yield dict(img=img, label=labels.get(fname))
103
+ if idx >= max_idx-1:
104
+ break
105
+ return max_idx, iterate_images()
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
110
+ import cv2 # pip install opencv-python
111
+ import lmdb # pip install lmdb # pylint: disable=import-error
112
+
113
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
114
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
115
+
116
+ def iterate_images():
117
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
118
+ for idx, (_key, value) in enumerate(txn.cursor()):
119
+ try:
120
+ try:
121
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
122
+ if img is None:
123
+ raise IOError('cv2.imdecode failed')
124
+ img = img[:, :, ::-1] # BGR => RGB
125
+ except IOError:
126
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
127
+ yield dict(img=img, label=None)
128
+ if idx >= max_idx-1:
129
+ break
130
+ except:
131
+ print(sys.exc_info()[1])
132
+
133
+ return max_idx, iterate_images()
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ def open_cifar10(tarball: str, *, max_images: Optional[int]):
138
+ images = []
139
+ labels = []
140
+
141
+ with tarfile.open(tarball, 'r:gz') as tar:
142
+ for batch in range(1, 6):
143
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
144
+ with tar.extractfile(member) as file:
145
+ data = pickle.load(file, encoding='latin1')
146
+ images.append(data['data'].reshape(-1, 3, 32, 32))
147
+ labels.append(data['labels'])
148
+
149
+ images = np.concatenate(images)
150
+ labels = np.concatenate(labels)
151
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
152
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
153
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
154
+ assert np.min(images) == 0 and np.max(images) == 255
155
+ assert np.min(labels) == 0 and np.max(labels) == 9
156
+
157
+ max_idx = maybe_min(len(images), max_images)
158
+
159
+ def iterate_images():
160
+ for idx, img in enumerate(images):
161
+ yield dict(img=img, label=int(labels[idx]))
162
+ if idx >= max_idx-1:
163
+ break
164
+
165
+ return max_idx, iterate_images()
166
+
167
+ #----------------------------------------------------------------------------
168
+
169
+ def open_mnist(images_gz: str, *, max_images: Optional[int]):
170
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
171
+ assert labels_gz != images_gz
172
+ images = []
173
+ labels = []
174
+
175
+ with gzip.open(images_gz, 'rb') as f:
176
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
177
+ with gzip.open(labels_gz, 'rb') as f:
178
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
179
+
180
+ images = images.reshape(-1, 28, 28)
181
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
182
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
183
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
184
+ assert np.min(images) == 0 and np.max(images) == 255
185
+ assert np.min(labels) == 0 and np.max(labels) == 9
186
+
187
+ max_idx = maybe_min(len(images), max_images)
188
+
189
+ def iterate_images():
190
+ for idx, img in enumerate(images):
191
+ yield dict(img=img, label=int(labels[idx]))
192
+ if idx >= max_idx-1:
193
+ break
194
+
195
+ return max_idx, iterate_images()
196
+
197
+ #----------------------------------------------------------------------------
198
+
199
+ def make_transform(
200
+ transform: Optional[str],
201
+ output_width: Optional[int],
202
+ output_height: Optional[int],
203
+ resize_filter: str
204
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
205
+ resample = { 'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS }[resize_filter]
206
+ def scale(width, height, img):
207
+ w = img.shape[1]
208
+ h = img.shape[0]
209
+ if width == w and height == h:
210
+ return img
211
+ img = PIL.Image.fromarray(img)
212
+ ww = width if width is not None else w
213
+ hh = height if height is not None else h
214
+ img = img.resize((ww, hh), resample)
215
+ return np.array(img)
216
+
217
+ def center_crop(width, height, img):
218
+ crop = np.min(img.shape[:2])
219
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
220
+ img = PIL.Image.fromarray(img, 'RGB')
221
+ img = img.resize((width, height), resample)
222
+ return np.array(img)
223
+
224
+ def center_crop_wide(width, height, img):
225
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
226
+ if img.shape[1] < width or ch < height:
227
+ return None
228
+
229
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
230
+ img = PIL.Image.fromarray(img, 'RGB')
231
+ img = img.resize((width, height), resample)
232
+ img = np.array(img)
233
+
234
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
235
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
236
+ return canvas
237
+
238
+ if transform is None:
239
+ return functools.partial(scale, output_width, output_height)
240
+ if transform == 'center-crop':
241
+ if (output_width is None) or (output_height is None):
242
+ error ('must specify --width and --height when using ' + transform + ' transform')
243
+ return functools.partial(center_crop, output_width, output_height)
244
+ if transform == 'center-crop-wide':
245
+ if (output_width is None) or (output_height is None):
246
+ error ('must specify --width and --height when using ' + transform + ' transform')
247
+ return functools.partial(center_crop_wide, output_width, output_height)
248
+ assert False, 'unknown transform'
249
+
250
+ #----------------------------------------------------------------------------
251
+
252
+ def open_dataset(source, *, max_images: Optional[int]):
253
+ if os.path.isdir(source):
254
+ if source.rstrip('/').endswith('_lmdb'):
255
+ return open_lmdb(source, max_images=max_images)
256
+ else:
257
+ return open_image_folder(source, max_images=max_images)
258
+ elif os.path.isfile(source):
259
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
260
+ return open_cifar10(source, max_images=max_images)
261
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
262
+ return open_mnist(source, max_images=max_images)
263
+ elif file_ext(source) == 'zip':
264
+ return open_image_zip(source, max_images=max_images)
265
+ else:
266
+ assert False, 'unknown archive type'
267
+ else:
268
+ error(f'Missing input file or directory: {source}')
269
+
270
+ #----------------------------------------------------------------------------
271
+
272
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
273
+ dest_ext = file_ext(dest)
274
+
275
+ if dest_ext == 'zip':
276
+ if os.path.dirname(dest) != '':
277
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
278
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
279
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
280
+ zf.writestr(fname, data)
281
+ return '', zip_write_bytes, zf.close
282
+ else:
283
+ # If the output folder already exists, check that it is
284
+ # empty.
285
+ #
286
+ # Note: creating the output directory is not strictly
287
+ # necessary as folder_write_bytes() also mkdirs, but it's better
288
+ # to give an error message earlier in case the dest folder
289
+ # somehow cannot be created.
290
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
291
+ error('--dest folder must be empty')
292
+ os.makedirs(dest, exist_ok=True)
293
+
294
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
295
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
296
+ with open(fname, 'wb') as fout:
297
+ if isinstance(data, str):
298
+ data = data.encode('utf8')
299
+ fout.write(data)
300
+ return dest, folder_write_bytes, lambda: None
301
+
302
+ #----------------------------------------------------------------------------
303
+
304
+ @click.command()
305
+ @click.pass_context
306
+ @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
307
+ @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
308
+ @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
309
+ @click.option('--resize-filter', help='Filter to use when resizing images for output resolution', type=click.Choice(['box', 'lanczos']), default='lanczos', show_default=True)
310
+ @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
311
+ @click.option('--width', help='Output width', type=int)
312
+ @click.option('--height', help='Output height', type=int)
313
+ def convert_dataset(
314
+ ctx: click.Context,
315
+ source: str,
316
+ dest: str,
317
+ max_images: Optional[int],
318
+ transform: Optional[str],
319
+ resize_filter: str,
320
+ width: Optional[int],
321
+ height: Optional[int]
322
+ ):
323
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
324
+
325
+ The input dataset format is guessed from the --source argument:
326
+
327
+ \b
328
+ --source *_lmdb/ Load LSUN dataset
329
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
330
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
331
+ --source path/ Recursively load all images from path/
332
+ --source dataset.zip Recursively load all images from dataset.zip
333
+
334
+ Specifying the output format and path:
335
+
336
+ \b
337
+ --dest /path/to/dir Save output files under /path/to/dir
338
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
339
+
340
+ The output dataset format can be either an image folder or an uncompressed zip archive.
341
+ Zip archives make it easier to move datasets around file servers and clusters, and may
342
+ offer better training performance on network file systems.
343
+
344
+ Images within the dataset archive will be stored as uncompressed PNG.
345
+ Uncompressed PNGs can be efficiently decoded in the training loop.
346
+
347
+ Class labels are stored in a file called 'dataset.json' that is stored at the
348
+ dataset root folder. This file has the following structure:
349
+
350
+ \b
351
+ {
352
+ "labels": [
353
+ ["00000/img00000000.png",6],
354
+ ["00000/img00000001.png",9],
355
+ ... repeated for every image in the dataset
356
+ ["00049/img00049999.png",1]
357
+ ]
358
+ }
359
+
360
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
361
+ not containing class labels.
362
+
363
+ Image scale/crop and resolution requirements:
364
+
365
+ Output images must be square-shaped and they must all have the same power-of-two
366
+ dimensions.
367
+
368
+ To scale arbitrary input image size to a specific width and height, use the
369
+ --width and --height options. Output resolution will be either the original
370
+ input resolution (if --width/--height was not specified) or the one specified with
371
+ --width/height.
372
+
373
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
374
+ center crop transform on the input image. These options should be used with the
375
+ --width and --height options. For example:
376
+
377
+ \b
378
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
379
+ --transform=center-crop-wide --width 512 --height=384
380
+ """
381
+
382
+ PIL.Image.init() # type: ignore
383
+
384
+ if dest == '':
385
+ ctx.fail('--dest output filename or directory must not be an empty string')
386
+
387
+ num_files, input_iter = open_dataset(source, max_images=max_images)
388
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
389
+
390
+ transform_image = make_transform(transform, width, height, resize_filter)
391
+
392
+ dataset_attrs = None
393
+
394
+ labels = []
395
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
396
+ idx_str = f'{idx:08d}'
397
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
398
+
399
+ # Apply crop and resize.
400
+ img = transform_image(image['img'])
401
+
402
+ # Transform may drop images.
403
+ if img is None:
404
+ continue
405
+
406
+ # Error check to require uniform image attributes across
407
+ # the whole dataset.
408
+ channels = img.shape[2] if img.ndim == 3 else 1
409
+ cur_image_attrs = {
410
+ 'width': img.shape[1],
411
+ 'height': img.shape[0],
412
+ 'channels': channels
413
+ }
414
+ if dataset_attrs is None:
415
+ dataset_attrs = cur_image_attrs
416
+ width = dataset_attrs['width']
417
+ height = dataset_attrs['height']
418
+ if width != height:
419
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
420
+ if dataset_attrs['channels'] not in [1, 3]:
421
+ error('Input images must be stored as RGB or grayscale')
422
+ if width != 2 ** int(np.floor(np.log2(width))):
423
+ error('Image width/height after scale and crop are required to be power-of-two')
424
+ elif dataset_attrs != cur_image_attrs:
425
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
426
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
427
+
428
+ # Save the image as an uncompressed PNG.
429
+ img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
430
+ image_bits = io.BytesIO()
431
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
432
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
433
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
434
+
435
+ metadata = {
436
+ 'labels': labels if all(x is not None for x in labels) else None
437
+ }
438
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
439
+ close_dest()
440
+
441
+ #----------------------------------------------------------------------------
442
+
443
+ if __name__ == "__main__":
444
+ convert_dataset() # pylint: disable=no-value-for-parameter
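
As a supplement to the docstring above, here is a hedged sketch of reading one image and its (optional) label back out of an archive produced by dataset_tool.py. The archive name is hypothetical; the layout (uncompressed PNGs under 00000/imgXXXXXXXX.png plus a dataset.json whose "labels" entry is a list of [filename, label] pairs or null) follows the description in convert_dataset().

    # Sketch: inspect one entry of a dataset archive written by dataset_tool.py.
    import json
    import zipfile
    import PIL.Image

    with zipfile.ZipFile('datasets/places_512.zip') as z:        # hypothetical archive
        meta = json.loads(z.read('dataset.json'))
        labels = dict(meta['labels']) if meta['labels'] else {}  # None when unlabeled
        first = sorted(n for n in z.namelist() if n.endswith('.png'))[0]
        with z.open(first) as f:
            img = PIL.Image.open(f)
            img.load()                                           # decode before the file closes
        print(first, img.size, labels.get(first))
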
datasets/dataset_256.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import cv2
10
+ import os
11
+ import numpy as np
12
+ import zipfile
13
+ import PIL.Image
14
+ import json
15
+ import torch
16
+ import dnnlib
17
+ import random
18
+
19
+ try:
20
+ import pyspng
21
+ except ImportError:
22
+ pyspng = None
23
+
24
+ from datasets.mask_generator_256 import RandomMask
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ class Dataset(torch.utils.data.Dataset):
29
+ def __init__(self,
30
+ name, # Name of the dataset.
31
+ raw_shape, # Shape of the raw image data (NCHW).
32
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
33
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
34
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
35
+ random_seed = 0, # Random seed to use when applying max_size.
36
+ ):
37
+ self._name = name
38
+ self._raw_shape = list(raw_shape)
39
+ self._use_labels = use_labels
40
+ self._raw_labels = None
41
+ self._label_shape = None
42
+
43
+ # Apply max_size.
44
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
45
+ if (max_size is not None) and (self._raw_idx.size > max_size):
46
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
47
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
48
+
49
+ # Apply xflip.
50
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
51
+ if xflip:
52
+ self._raw_idx = np.tile(self._raw_idx, 2)
53
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
54
+
55
+ def _get_raw_labels(self):
56
+ if self._raw_labels is None:
57
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
58
+ if self._raw_labels is None:
59
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
60
+ assert isinstance(self._raw_labels, np.ndarray)
61
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
62
+ assert self._raw_labels.dtype in [np.float32, np.int64]
63
+ if self._raw_labels.dtype == np.int64:
64
+ assert self._raw_labels.ndim == 1
65
+ assert np.all(self._raw_labels >= 0)
66
+ return self._raw_labels
67
+
68
+ def close(self): # to be overridden by subclass
69
+ pass
70
+
71
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
72
+ raise NotImplementedError
73
+
74
+ def _load_raw_labels(self): # to be overridden by subclass
75
+ raise NotImplementedError
76
+
77
+ def __getstate__(self):
78
+ return dict(self.__dict__, _raw_labels=None)
79
+
80
+ def __del__(self):
81
+ try:
82
+ self.close()
83
+ except:
84
+ pass
85
+
86
+ def __len__(self):
87
+ return self._raw_idx.size
88
+
89
+ def __getitem__(self, idx):
90
+ image = self._load_raw_image(self._raw_idx[idx])
91
+ assert isinstance(image, np.ndarray)
92
+ assert list(image.shape) == self.image_shape
93
+ assert image.dtype == np.uint8
94
+ if self._xflip[idx]:
95
+ assert image.ndim == 3 # CHW
96
+ image = image[:, :, ::-1]
97
+ return image.copy(), self.get_label(idx)
98
+
99
+ def get_label(self, idx):
100
+ label = self._get_raw_labels()[self._raw_idx[idx]]
101
+ if label.dtype == np.int64:
102
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
103
+ onehot[label] = 1
104
+ label = onehot
105
+ return label.copy()
106
+
107
+ def get_details(self, idx):
108
+ d = dnnlib.EasyDict()
109
+ d.raw_idx = int(self._raw_idx[idx])
110
+ d.xflip = (int(self._xflip[idx]) != 0)
111
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
112
+ return d
113
+
114
+ @property
115
+ def name(self):
116
+ return self._name
117
+
118
+ @property
119
+ def image_shape(self):
120
+ return list(self._raw_shape[1:])
121
+
122
+ @property
123
+ def num_channels(self):
124
+ assert len(self.image_shape) == 3 # CHW
125
+ return self.image_shape[0]
126
+
127
+ @property
128
+ def resolution(self):
129
+ assert len(self.image_shape) == 3 # CHW
130
+ assert self.image_shape[1] == self.image_shape[2]
131
+ return self.image_shape[1]
132
+
133
+ @property
134
+ def label_shape(self):
135
+ if self._label_shape is None:
136
+ raw_labels = self._get_raw_labels()
137
+ if raw_labels.dtype == np.int64:
138
+ self._label_shape = [int(np.max(raw_labels)) + 1]
139
+ else:
140
+ self._label_shape = raw_labels.shape[1:]
141
+ return list(self._label_shape)
142
+
143
+ @property
144
+ def label_dim(self):
145
+ assert len(self.label_shape) == 1
146
+ return self.label_shape[0]
147
+
148
+ @property
149
+ def has_labels(self):
150
+ return any(x != 0 for x in self.label_shape)
151
+
152
+ @property
153
+ def has_onehot_labels(self):
154
+ return self._get_raw_labels().dtype == np.int64
155
+
156
+
157
+ #----------------------------------------------------------------------------
158
+
159
+
160
+ class ImageFolderMaskDataset(Dataset):
161
+ def __init__(self,
162
+ path, # Path to directory or zip.
163
+ resolution = None, # Ensure specific resolution, None = highest available.
164
+ hole_range=[0,1],
165
+ **super_kwargs, # Additional arguments for the Dataset base class.
166
+ ):
167
+ self._path = path
168
+ self._zipfile = None
169
+ self._hole_range = hole_range
170
+
171
+ if os.path.isdir(self._path):
172
+ self._type = 'dir'
173
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
174
+ elif self._file_ext(self._path) == '.zip':
175
+ self._type = 'zip'
176
+ self._all_fnames = set(self._get_zipfile().namelist())
177
+ else:
178
+ raise IOError('Path must point to a directory or zip')
179
+
180
+ PIL.Image.init()
181
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
182
+ if len(self._image_fnames) == 0:
183
+ raise IOError('No image files found in the specified path')
184
+
185
+ name = os.path.splitext(os.path.basename(self._path))[0]
186
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
187
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
188
+ raise IOError('Image files do not match the specified resolution')
189
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
190
+
191
+ @staticmethod
192
+ def _file_ext(fname):
193
+ return os.path.splitext(fname)[1].lower()
194
+
195
+ def _get_zipfile(self):
196
+ assert self._type == 'zip'
197
+ if self._zipfile is None:
198
+ self._zipfile = zipfile.ZipFile(self._path)
199
+ return self._zipfile
200
+
201
+ def _open_file(self, fname):
202
+ if self._type == 'dir':
203
+ return open(os.path.join(self._path, fname), 'rb')
204
+ if self._type == 'zip':
205
+ return self._get_zipfile().open(fname, 'r')
206
+ return None
207
+
208
+ def close(self):
209
+ try:
210
+ if self._zipfile is not None:
211
+ self._zipfile.close()
212
+ finally:
213
+ self._zipfile = None
214
+
215
+ def __getstate__(self):
216
+ return dict(super().__getstate__(), _zipfile=None)
217
+
218
+ def _load_raw_image(self, raw_idx):
219
+ fname = self._image_fnames[raw_idx]
220
+ with self._open_file(fname) as f:
221
+ if pyspng is not None and self._file_ext(fname) == '.png':
222
+ image = pyspng.load(f.read())
223
+ else:
224
+ image = np.array(PIL.Image.open(f))
225
+ if image.ndim == 2:
226
+ image = image[:, :, np.newaxis] # HW => HWC
227
+
228
+ # for grayscale image
229
+ if image.shape[2] == 1:
230
+ image = np.repeat(image, 3, axis=2)
231
+
232
+ # pad with reflection if the image is smaller than 256, then take a random 256x256 crop
233
+ res = 256
234
+ H, W, C = image.shape
235
+ if H < res or W < res:
236
+ top = 0
237
+ bottom = max(0, res - H)
238
+ left = 0
239
+ right = max(0, res - W)
240
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
241
+ H, W, C = image.shape
242
+ h = random.randint(0, H - res)
243
+ w = random.randint(0, W - res)
244
+ image = image[h:h+res, w:w+res, :]
245
+
246
+ image = np.ascontiguousarray(image.transpose(2, 0, 1)) # HWC => CHW
247
+
248
+ return image
249
+
250
+ def _load_raw_labels(self):
251
+ fname = 'labels.json'
252
+ if fname not in self._all_fnames:
253
+ return None
254
+ with self._open_file(fname) as f:
255
+ labels = json.load(f)['labels']
256
+ if labels is None:
257
+ return None
258
+ labels = dict(labels)
259
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
260
+ labels = np.array(labels)
261
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
262
+ return labels
263
+
264
+ def __getitem__(self, idx):
265
+ image = self._load_raw_image(self._raw_idx[idx])
266
+
267
+ assert isinstance(image, np.ndarray)
268
+ assert list(image.shape) == self.image_shape
269
+ assert image.dtype == np.uint8
270
+ if self._xflip[idx]:
271
+ assert image.ndim == 3 # CHW
272
+ image = image[:, :, ::-1]
273
+ mask = RandomMask(image.shape[-1], hole_range=self._hole_range) # hole as 0, reserved as 1
274
+ return image.copy(), mask, self.get_label(idx)
275
+
276
+
277
+ if __name__ == '__main__':
278
+ res = 256
279
+ dpath = '/data/liwenbo/datasets/Places365/standard/val_256'
280
+ D = ImageFolderMaskDataset(path=dpath)
281
+ print(D.__len__())
282
+ for i in range(D.__len__()):
283
+ print(i)
284
+ a, b, c = D.__getitem__(i)
285
+ if a.shape != (3, 256, 256):
286
+ print(i, a.shape)
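A minimal usage sketch for the training dataset above (illustrative only, not part of this commit): it wraps `ImageFolderMaskDataset` in a standard PyTorch `DataLoader`; the image directory is a placeholder and the `[-1, 1]` normalization is an assumption about how downstream training code consumes the batches.

```python
# Illustrative only -- not part of this commit. './my_images' is a placeholder path.
import torch
from datasets.dataset_256 import ImageFolderMaskDataset

dataset = ImageFolderMaskDataset(path='./my_images', hole_range=[0.1, 0.9])
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

for images, masks, labels in loader:
    # images: uint8 CHW in [0, 255]; masks: float32 (1, 256, 256) with holes as 0, kept pixels as 1
    images = images.float() / 127.5 - 1.0   # assumed [-1, 1] normalization
    erased = images * masks                  # zero out the hole regions
    break
```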
datasets/dataset_256_val.py ADDED
@@ -0,0 +1,282 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import zipfile
12
+ import PIL.Image
13
+ import cv2
14
+ import json
15
+ import torch
16
+ import dnnlib
17
+ import glob
18
+
19
+ try:
20
+ import pyspng
21
+ except ImportError:
22
+ pyspng = None
23
+
24
+ from datasets.mask_generator_256 import RandomMask
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ class Dataset(torch.utils.data.Dataset):
29
+ def __init__(self,
30
+ name, # Name of the dataset.
31
+ raw_shape, # Shape of the raw image data (NCHW).
32
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
33
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
34
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
35
+ random_seed = 0, # Random seed to use when applying max_size.
36
+ ):
37
+ self._name = name
38
+ self._raw_shape = list(raw_shape)
39
+ self._use_labels = use_labels
40
+ self._raw_labels = None
41
+ self._label_shape = None
42
+
43
+ # Apply max_size.
44
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
45
+ if (max_size is not None) and (self._raw_idx.size > max_size):
46
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
47
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
48
+
49
+ # Apply xflip.
50
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
51
+ if xflip:
52
+ self._raw_idx = np.tile(self._raw_idx, 2)
53
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
54
+
55
+ def _get_raw_labels(self):
56
+ if self._raw_labels is None:
57
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
58
+ if self._raw_labels is None:
59
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
60
+ assert isinstance(self._raw_labels, np.ndarray)
61
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
62
+ assert self._raw_labels.dtype in [np.float32, np.int64]
63
+ if self._raw_labels.dtype == np.int64:
64
+ assert self._raw_labels.ndim == 1
65
+ assert np.all(self._raw_labels >= 0)
66
+ return self._raw_labels
67
+
68
+ def close(self): # to be overridden by subclass
69
+ pass
70
+
71
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
72
+ raise NotImplementedError
73
+
74
+ def _load_raw_labels(self): # to be overridden by subclass
75
+ raise NotImplementedError
76
+
77
+ def __getstate__(self):
78
+ return dict(self.__dict__, _raw_labels=None)
79
+
80
+ def __del__(self):
81
+ try:
82
+ self.close()
83
+ except:
84
+ pass
85
+
86
+ def __len__(self):
87
+ return self._raw_idx.size
88
+
89
+ def __getitem__(self, idx):
90
+ image = self._load_raw_image(self._raw_idx[idx])
91
+ assert isinstance(image, np.ndarray)
92
+ assert list(image.shape) == self.image_shape
93
+ assert image.dtype == np.uint8
94
+ if self._xflip[idx]:
95
+ assert image.ndim == 3 # CHW
96
+ image = image[:, :, ::-1]
97
+ return image.copy(), self.get_label(idx)
98
+
99
+ def get_label(self, idx):
100
+ label = self._get_raw_labels()[self._raw_idx[idx]]
101
+ if label.dtype == np.int64:
102
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
103
+ onehot[label] = 1
104
+ label = onehot
105
+ return label.copy()
106
+
107
+ def get_details(self, idx):
108
+ d = dnnlib.EasyDict()
109
+ d.raw_idx = int(self._raw_idx[idx])
110
+ d.xflip = (int(self._xflip[idx]) != 0)
111
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
112
+ return d
113
+
114
+ @property
115
+ def name(self):
116
+ return self._name
117
+
118
+ @property
119
+ def image_shape(self):
120
+ return list(self._raw_shape[1:])
121
+
122
+ @property
123
+ def num_channels(self):
124
+ assert len(self.image_shape) == 3 # CHW
125
+ return self.image_shape[0]
126
+
127
+ @property
128
+ def resolution(self):
129
+ assert len(self.image_shape) == 3 # CHW
130
+ assert self.image_shape[1] == self.image_shape[2]
131
+ return self.image_shape[1]
132
+
133
+ @property
134
+ def label_shape(self):
135
+ if self._label_shape is None:
136
+ raw_labels = self._get_raw_labels()
137
+ if raw_labels.dtype == np.int64:
138
+ self._label_shape = [int(np.max(raw_labels)) + 1]
139
+ else:
140
+ self._label_shape = raw_labels.shape[1:]
141
+ return list(self._label_shape)
142
+
143
+ @property
144
+ def label_dim(self):
145
+ assert len(self.label_shape) == 1
146
+ return self.label_shape[0]
147
+
148
+ @property
149
+ def has_labels(self):
150
+ return any(x != 0 for x in self.label_shape)
151
+
152
+ @property
153
+ def has_onehot_labels(self):
154
+ return self._get_raw_labels().dtype == np.int64
155
+
156
+
157
+ #----------------------------------------------------------------------------
158
+
159
+
160
+ class ImageFolderMaskDataset(Dataset):
161
+ def __init__(self,
162
+ path, # Path to directory or zip.
163
+ resolution = None, # Ensure specific resolution, None = highest available.
164
+ hole_range=[0,1],
165
+ **super_kwargs, # Additional arguments for the Dataset base class.
166
+ ):
167
+ self._path = path
168
+ self._zipfile = None
169
+ self._hole_range = hole_range
170
+
171
+ if os.path.isdir(self._path):
172
+ self._type = 'dir'
173
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
174
+ elif self._file_ext(self._path) == '.zip':
175
+ self._type = 'zip'
176
+ self._all_fnames = set(self._get_zipfile().namelist())
177
+ else:
178
+ raise IOError('Path must point to a directory or zip')
179
+
180
+ PIL.Image.init()
181
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
182
+ if len(self._image_fnames) == 0:
183
+ raise IOError('No image files found in the specified path')
184
+
185
+ name = os.path.splitext(os.path.basename(self._path))[0]
186
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
187
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
188
+ raise IOError('Image files do not match the specified resolution')
189
+ self._load_mask()
190
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
191
+
192
+ def _load_mask(self, mpath='/data/liwenbo/datasets/Places365/standard/masks_val_256_eval'):
193
+ self.masks = sorted(glob.glob(mpath + '/*.png'))
194
+
195
+ @staticmethod
196
+ def _file_ext(fname):
197
+ return os.path.splitext(fname)[1].lower()
198
+
199
+ def _get_zipfile(self):
200
+ assert self._type == 'zip'
201
+ if self._zipfile is None:
202
+ self._zipfile = zipfile.ZipFile(self._path)
203
+ return self._zipfile
204
+
205
+ def _open_file(self, fname):
206
+ if self._type == 'dir':
207
+ return open(os.path.join(self._path, fname), 'rb')
208
+ if self._type == 'zip':
209
+ return self._get_zipfile().open(fname, 'r')
210
+ return None
211
+
212
+ def close(self):
213
+ try:
214
+ if self._zipfile is not None:
215
+ self._zipfile.close()
216
+ finally:
217
+ self._zipfile = None
218
+
219
+ def __getstate__(self):
220
+ return dict(super().__getstate__(), _zipfile=None)
221
+
222
+ def _load_raw_image(self, raw_idx):
223
+ fname = self._image_fnames[raw_idx]
224
+ with self._open_file(fname) as f:
225
+ if pyspng is not None and self._file_ext(fname) == '.png':
226
+ image = pyspng.load(f.read())
227
+ else:
228
+ image = np.array(PIL.Image.open(f))
229
+ if image.ndim == 2:
230
+ image = image[:, :, np.newaxis] # HW => HWC
231
+
232
+ # for grayscale image
233
+ if image.shape[2] == 1:
234
+ image = np.repeat(image, 3, axis=2)
235
+
236
+ # pad with reflection if the image is smaller than 256, then take a center 256x256 crop
237
+ res = 256
238
+ H, W, C = image.shape
239
+ if H < res or W < res:
240
+ top = 0
241
+ bottom = max(0, res - H)
242
+ left = 0
243
+ right = max(0, res - W)
244
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
245
+ H, W, C = image.shape
246
+ h = (H - res) // 2
247
+ w = (W - res) // 2
248
+ image = image[h:h+res, w:w+res, :]
249
+
250
+ image = np.ascontiguousarray(image.transpose(2, 0, 1)) # HWC => CHW
251
+ return image
252
+
253
+ def _load_raw_labels(self):
254
+ fname = 'labels.json'
255
+ if fname not in self._all_fnames:
256
+ return None
257
+ with self._open_file(fname) as f:
258
+ labels = json.load(f)['labels']
259
+ if labels is None:
260
+ return None
261
+ labels = dict(labels)
262
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
263
+ labels = np.array(labels)
264
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
265
+ return labels
266
+
267
+ def __getitem__(self, idx):
268
+ image = self._load_raw_image(self._raw_idx[idx])
269
+
270
+ # for grayscale image
271
+ if image.shape[0] == 1:
272
+ image = np.repeat(image, 3, axis=0)
273
+
274
+ assert isinstance(image, np.ndarray)
275
+ assert list(image.shape) == self.image_shape
276
+ assert image.dtype == np.uint8
277
+ if self._xflip[idx]:
278
+ assert image.ndim == 3 # CHW
279
+ image = image[:, :, ::-1]
280
+ # mask = RandomMask(image.shape[-1], hole_range=self._hole_range) # hole as 0, reserved as 1
281
+ mask = cv2.imread(self.masks[idx], cv2.IMREAD_GRAYSCALE).astype(np.float32)[np.newaxis, :, :] / 255.0
282
+ return image.copy(), mask, self.get_label(idx)
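The validation variant pairs each image with a fixed mask loaded from a hard-coded directory. A short sketch (illustrative only, not part of this commit; both directories are placeholders) of re-pointing that directory; it assumes the mask files sort into the same order as the image files.

```python
# Illustrative only -- not part of this commit; placeholder directories.
from datasets.dataset_256_val import ImageFolderMaskDataset

val_set = ImageFolderMaskDataset(path='./val_images_256')
val_set._load_mask(mpath='./val_masks_256')   # re-point the hard-coded mask directory

image, mask, label = val_set[0]
print(image.shape, mask.shape, float(mask.min()), float(mask.max()))
# e.g. (3, 256, 256) (1, 256, 256) 0.0 1.0
```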
datasets/dataset_512.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import cv2
10
+ import os
11
+ import numpy as np
12
+ import zipfile
13
+ import PIL.Image
14
+ import json
15
+ import torch
16
+ import dnnlib
17
+ import random
18
+
19
+ try:
20
+ import pyspng
21
+ except ImportError:
22
+ pyspng = None
23
+
24
+ from datasets.mask_generator_512 import RandomMask
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ class Dataset(torch.utils.data.Dataset):
29
+ def __init__(self,
30
+ name, # Name of the dataset.
31
+ raw_shape, # Shape of the raw image data (NCHW).
32
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
33
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
34
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
35
+ random_seed = 0, # Random seed to use when applying max_size.
36
+ ):
37
+ self._name = name
38
+ self._raw_shape = list(raw_shape)
39
+ self._use_labels = use_labels
40
+ self._raw_labels = None
41
+ self._label_shape = None
42
+
43
+ # Apply max_size.
44
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
45
+ if (max_size is not None) and (self._raw_idx.size > max_size):
46
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
47
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
48
+
49
+ # Apply xflip.
50
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
51
+ if xflip:
52
+ self._raw_idx = np.tile(self._raw_idx, 2)
53
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
54
+
55
+ def _get_raw_labels(self):
56
+ if self._raw_labels is None:
57
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
58
+ if self._raw_labels is None:
59
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
60
+ assert isinstance(self._raw_labels, np.ndarray)
61
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
62
+ assert self._raw_labels.dtype in [np.float32, np.int64]
63
+ if self._raw_labels.dtype == np.int64:
64
+ assert self._raw_labels.ndim == 1
65
+ assert np.all(self._raw_labels >= 0)
66
+ return self._raw_labels
67
+
68
+ def close(self): # to be overridden by subclass
69
+ pass
70
+
71
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
72
+ raise NotImplementedError
73
+
74
+ def _load_raw_labels(self): # to be overridden by subclass
75
+ raise NotImplementedError
76
+
77
+ def __getstate__(self):
78
+ return dict(self.__dict__, _raw_labels=None)
79
+
80
+ def __del__(self):
81
+ try:
82
+ self.close()
83
+ except:
84
+ pass
85
+
86
+ def __len__(self):
87
+ return self._raw_idx.size
88
+
89
+ def __getitem__(self, idx):
90
+ image = self._load_raw_image(self._raw_idx[idx])
91
+ assert isinstance(image, np.ndarray)
92
+ assert list(image.shape) == self.image_shape
93
+ assert image.dtype == np.uint8
94
+ if self._xflip[idx]:
95
+ assert image.ndim == 3 # CHW
96
+ image = image[:, :, ::-1]
97
+ return image.copy(), self.get_label(idx)
98
+
99
+ def get_label(self, idx):
100
+ label = self._get_raw_labels()[self._raw_idx[idx]]
101
+ if label.dtype == np.int64:
102
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
103
+ onehot[label] = 1
104
+ label = onehot
105
+ return label.copy()
106
+
107
+ def get_details(self, idx):
108
+ d = dnnlib.EasyDict()
109
+ d.raw_idx = int(self._raw_idx[idx])
110
+ d.xflip = (int(self._xflip[idx]) != 0)
111
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
112
+ return d
113
+
114
+ @property
115
+ def name(self):
116
+ return self._name
117
+
118
+ @property
119
+ def image_shape(self):
120
+ return list(self._raw_shape[1:])
121
+
122
+ @property
123
+ def num_channels(self):
124
+ assert len(self.image_shape) == 3 # CHW
125
+ return self.image_shape[0]
126
+
127
+ @property
128
+ def resolution(self):
129
+ assert len(self.image_shape) == 3 # CHW
130
+ assert self.image_shape[1] == self.image_shape[2]
131
+ return self.image_shape[1]
132
+
133
+ @property
134
+ def label_shape(self):
135
+ if self._label_shape is None:
136
+ raw_labels = self._get_raw_labels()
137
+ if raw_labels.dtype == np.int64:
138
+ self._label_shape = [int(np.max(raw_labels)) + 1]
139
+ else:
140
+ self._label_shape = raw_labels.shape[1:]
141
+ return list(self._label_shape)
142
+
143
+ @property
144
+ def label_dim(self):
145
+ assert len(self.label_shape) == 1
146
+ return self.label_shape[0]
147
+
148
+ @property
149
+ def has_labels(self):
150
+ return any(x != 0 for x in self.label_shape)
151
+
152
+ @property
153
+ def has_onehot_labels(self):
154
+ return self._get_raw_labels().dtype == np.int64
155
+
156
+
157
+ #----------------------------------------------------------------------------
158
+
159
+
160
+ class ImageFolderMaskDataset(Dataset):
161
+ def __init__(self,
162
+ path, # Path to directory or zip.
163
+ resolution = None, # Ensure specific resolution, None = highest available.
164
+ hole_range=[0,1],
165
+ **super_kwargs, # Additional arguments for the Dataset base class.
166
+ ):
167
+ self._path = path
168
+ self._zipfile = None
169
+ self._hole_range = hole_range
170
+
171
+ if os.path.isdir(self._path):
172
+ self._type = 'dir'
173
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
174
+ elif self._file_ext(self._path) == '.zip':
175
+ self._type = 'zip'
176
+ self._all_fnames = set(self._get_zipfile().namelist())
177
+ else:
178
+ raise IOError('Path must point to a directory or zip')
179
+
180
+ PIL.Image.init()
181
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
182
+ if len(self._image_fnames) == 0:
183
+ raise IOError('No image files found in the specified path')
184
+
185
+ name = os.path.splitext(os.path.basename(self._path))[0]
186
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
187
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
188
+ raise IOError('Image files do not match the specified resolution')
189
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
190
+
191
+ @staticmethod
192
+ def _file_ext(fname):
193
+ return os.path.splitext(fname)[1].lower()
194
+
195
+ def _get_zipfile(self):
196
+ assert self._type == 'zip'
197
+ if self._zipfile is None:
198
+ self._zipfile = zipfile.ZipFile(self._path)
199
+ return self._zipfile
200
+
201
+ def _open_file(self, fname):
202
+ if self._type == 'dir':
203
+ return open(os.path.join(self._path, fname), 'rb')
204
+ if self._type == 'zip':
205
+ return self._get_zipfile().open(fname, 'r')
206
+ return None
207
+
208
+ def close(self):
209
+ try:
210
+ if self._zipfile is not None:
211
+ self._zipfile.close()
212
+ finally:
213
+ self._zipfile = None
214
+
215
+ def __getstate__(self):
216
+ return dict(super().__getstate__(), _zipfile=None)
217
+
218
+ def _load_raw_image(self, raw_idx):
219
+ fname = self._image_fnames[raw_idx]
220
+ with self._open_file(fname) as f:
221
+ if pyspng is not None and self._file_ext(fname) == '.png':
222
+ image = pyspng.load(f.read())
223
+ else:
224
+ image = np.array(PIL.Image.open(f))
225
+ if image.ndim == 2:
226
+ image = image[:, :, np.newaxis] # HW => HWC
227
+
228
+ # for grayscale image
229
+ if image.shape[2] == 1:
230
+ image = np.repeat(image, 3, axis=2)
231
+
232
+ # pad with reflection if the image is smaller than 512, then take a random 512x512 crop
233
+ res = 512
234
+ H, W, C = image.shape
235
+ if H < res or W < res:
236
+ top = 0
237
+ bottom = max(0, res - H)
238
+ left = 0
239
+ right = max(0, res - W)
240
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
241
+ H, W, C = image.shape
242
+ h = random.randint(0, H - res)
243
+ w = random.randint(0, W - res)
244
+ image = image[h:h+res, w:w+res, :]
245
+
246
+ image = np.ascontiguousarray(image.transpose(2, 0, 1)) # HWC => CHW
247
+
248
+ return image
249
+
250
+ def _load_raw_labels(self):
251
+ fname = 'labels.json'
252
+ if fname not in self._all_fnames:
253
+ return None
254
+ with self._open_file(fname) as f:
255
+ labels = json.load(f)['labels']
256
+ if labels is None:
257
+ return None
258
+ labels = dict(labels)
259
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
260
+ labels = np.array(labels)
261
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
262
+ return labels
263
+
264
+ def __getitem__(self, idx):
265
+ image = self._load_raw_image(self._raw_idx[idx])
266
+
267
+ assert isinstance(image, np.ndarray)
268
+ assert list(image.shape) == self.image_shape
269
+ assert image.dtype == np.uint8
270
+ if self._xflip[idx]:
271
+ assert image.ndim == 3 # CHW
272
+ image = image[:, :, ::-1]
273
+ mask = RandomMask(image.shape[-1], hole_range=self._hole_range) # hole as 0, reserved as 1
274
+ return image.copy(), mask, self.get_label(idx)
275
+
276
+
277
+ if __name__ == '__main__':
278
+ res = 512
279
+ dpath = '/data/liwenbo/datasets/Places365/standard/val_large'
280
+ D = ImageFolderMaskDataset(path=dpath)
281
+ print(D.__len__())
282
+ for i in range(D.__len__()):
283
+ print(i)
284
+ a, b, c = D.__getitem__(i)
285
+ if a.shape != (3, 512, 512):
286
+ print(i, a.shape)
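To make the pad-then-crop behaviour of `_load_raw_image` above concrete, here is a small sketch on a synthetic array (illustrative only, not part of this commit): images smaller than the target resolution are reflect-padded and then randomly cropped to 512x512.

```python
# Illustrative only -- not part of this commit; synthetic input array.
import random
import numpy as np
import cv2

img = np.zeros((400, 600, 3), np.uint8)        # shorter than 512 only in height
res = 512
H, W, _ = img.shape
if H < res or W < res:                          # reflect-pad up to the target size
    img = cv2.copyMakeBorder(img, 0, max(0, res - H), 0, max(0, res - W), cv2.BORDER_REFLECT)
H, W, _ = img.shape                             # (512, 600, 3)
h = random.randint(0, H - res)
w = random.randint(0, W - res)
print(img[h:h + res, w:w + res].shape)          # (512, 512, 3)
```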
datasets/dataset_512_val.py ADDED
@@ -0,0 +1,282 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import zipfile
12
+ import PIL.Image
13
+ import cv2
14
+ import json
15
+ import torch
16
+ import dnnlib
17
+ import glob
18
+
19
+ try:
20
+ import pyspng
21
+ except ImportError:
22
+ pyspng = None
23
+
24
+ from datasets.mask_generator_512 import RandomMask
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ class Dataset(torch.utils.data.Dataset):
29
+ def __init__(self,
30
+ name, # Name of the dataset.
31
+ raw_shape, # Shape of the raw image data (NCHW).
32
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
33
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
34
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
35
+ random_seed = 0, # Random seed to use when applying max_size.
36
+ ):
37
+ self._name = name
38
+ self._raw_shape = list(raw_shape)
39
+ self._use_labels = use_labels
40
+ self._raw_labels = None
41
+ self._label_shape = None
42
+
43
+ # Apply max_size.
44
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
45
+ if (max_size is not None) and (self._raw_idx.size > max_size):
46
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
47
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
48
+
49
+ # Apply xflip.
50
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
51
+ if xflip:
52
+ self._raw_idx = np.tile(self._raw_idx, 2)
53
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
54
+
55
+ def _get_raw_labels(self):
56
+ if self._raw_labels is None:
57
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
58
+ if self._raw_labels is None:
59
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
60
+ assert isinstance(self._raw_labels, np.ndarray)
61
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
62
+ assert self._raw_labels.dtype in [np.float32, np.int64]
63
+ if self._raw_labels.dtype == np.int64:
64
+ assert self._raw_labels.ndim == 1
65
+ assert np.all(self._raw_labels >= 0)
66
+ return self._raw_labels
67
+
68
+ def close(self): # to be overridden by subclass
69
+ pass
70
+
71
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
72
+ raise NotImplementedError
73
+
74
+ def _load_raw_labels(self): # to be overridden by subclass
75
+ raise NotImplementedError
76
+
77
+ def __getstate__(self):
78
+ return dict(self.__dict__, _raw_labels=None)
79
+
80
+ def __del__(self):
81
+ try:
82
+ self.close()
83
+ except:
84
+ pass
85
+
86
+ def __len__(self):
87
+ return self._raw_idx.size
88
+
89
+ def __getitem__(self, idx):
90
+ image = self._load_raw_image(self._raw_idx[idx])
91
+ assert isinstance(image, np.ndarray)
92
+ assert list(image.shape) == self.image_shape
93
+ assert image.dtype == np.uint8
94
+ if self._xflip[idx]:
95
+ assert image.ndim == 3 # CHW
96
+ image = image[:, :, ::-1]
97
+ return image.copy(), self.get_label(idx)
98
+
99
+ def get_label(self, idx):
100
+ label = self._get_raw_labels()[self._raw_idx[idx]]
101
+ if label.dtype == np.int64:
102
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
103
+ onehot[label] = 1
104
+ label = onehot
105
+ return label.copy()
106
+
107
+ def get_details(self, idx):
108
+ d = dnnlib.EasyDict()
109
+ d.raw_idx = int(self._raw_idx[idx])
110
+ d.xflip = (int(self._xflip[idx]) != 0)
111
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
112
+ return d
113
+
114
+ @property
115
+ def name(self):
116
+ return self._name
117
+
118
+ @property
119
+ def image_shape(self):
120
+ return list(self._raw_shape[1:])
121
+
122
+ @property
123
+ def num_channels(self):
124
+ assert len(self.image_shape) == 3 # CHW
125
+ return self.image_shape[0]
126
+
127
+ @property
128
+ def resolution(self):
129
+ assert len(self.image_shape) == 3 # CHW
130
+ assert self.image_shape[1] == self.image_shape[2]
131
+ return self.image_shape[1]
132
+
133
+ @property
134
+ def label_shape(self):
135
+ if self._label_shape is None:
136
+ raw_labels = self._get_raw_labels()
137
+ if raw_labels.dtype == np.int64:
138
+ self._label_shape = [int(np.max(raw_labels)) + 1]
139
+ else:
140
+ self._label_shape = raw_labels.shape[1:]
141
+ return list(self._label_shape)
142
+
143
+ @property
144
+ def label_dim(self):
145
+ assert len(self.label_shape) == 1
146
+ return self.label_shape[0]
147
+
148
+ @property
149
+ def has_labels(self):
150
+ return any(x != 0 for x in self.label_shape)
151
+
152
+ @property
153
+ def has_onehot_labels(self):
154
+ return self._get_raw_labels().dtype == np.int64
155
+
156
+
157
+ #----------------------------------------------------------------------------
158
+
159
+
160
+ class ImageFolderMaskDataset(Dataset):
161
+ def __init__(self,
162
+ path, # Path to directory or zip.
163
+ resolution = None, # Ensure specific resolution, None = highest available.
164
+ hole_range=[0,1],
165
+ **super_kwargs, # Additional arguments for the Dataset base class.
166
+ ):
167
+ self._path = path
168
+ self._zipfile = None
169
+ self._hole_range = hole_range
170
+
171
+ if os.path.isdir(self._path):
172
+ self._type = 'dir'
173
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
174
+ elif self._file_ext(self._path) == '.zip':
175
+ self._type = 'zip'
176
+ self._all_fnames = set(self._get_zipfile().namelist())
177
+ else:
178
+ raise IOError('Path must point to a directory or zip')
179
+
180
+ PIL.Image.init()
181
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
182
+ if len(self._image_fnames) == 0:
183
+ raise IOError('No image files found in the specified path')
184
+
185
+ name = os.path.splitext(os.path.basename(self._path))[0]
186
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
187
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
188
+ raise IOError('Image files do not match the specified resolution')
189
+ self._load_mask()
190
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
191
+
192
+ def _load_mask(self, mpath='/data/liwenbo/datasets/Places365/standard/masks_val_512_eval'):
193
+ self.masks = sorted(glob.glob(mpath + '/*.png'))
194
+
195
+ @staticmethod
196
+ def _file_ext(fname):
197
+ return os.path.splitext(fname)[1].lower()
198
+
199
+ def _get_zipfile(self):
200
+ assert self._type == 'zip'
201
+ if self._zipfile is None:
202
+ self._zipfile = zipfile.ZipFile(self._path)
203
+ return self._zipfile
204
+
205
+ def _open_file(self, fname):
206
+ if self._type == 'dir':
207
+ return open(os.path.join(self._path, fname), 'rb')
208
+ if self._type == 'zip':
209
+ return self._get_zipfile().open(fname, 'r')
210
+ return None
211
+
212
+ def close(self):
213
+ try:
214
+ if self._zipfile is not None:
215
+ self._zipfile.close()
216
+ finally:
217
+ self._zipfile = None
218
+
219
+ def __getstate__(self):
220
+ return dict(super().__getstate__(), _zipfile=None)
221
+
222
+ def _load_raw_image(self, raw_idx):
223
+ fname = self._image_fnames[raw_idx]
224
+ with self._open_file(fname) as f:
225
+ if pyspng is not None and self._file_ext(fname) == '.png':
226
+ image = pyspng.load(f.read())
227
+ else:
228
+ image = np.array(PIL.Image.open(f))
229
+ if image.ndim == 2:
230
+ image = image[:, :, np.newaxis] # HW => HWC
231
+
232
+ # for grayscale image
233
+ if image.shape[2] == 1:
234
+ image = np.repeat(image, 3, axis=2)
235
+
236
+ # restricted to 512x512
237
+ res = 512
238
+ H, W, C = image.shape
239
+ if H < res or W < res:
240
+ top = 0
241
+ bottom = max(0, res - H)
242
+ left = 0
243
+ right = max(0, res - W)
244
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
245
+ H, W, C = image.shape
246
+ h = (H - res) // 2
247
+ w = (W - res) // 2
248
+ image = image[h:h+res, w:w+res, :]
249
+
250
+ image = np.ascontiguousarray(image.transpose(2, 0, 1)) # HWC => CHW
251
+ return image
252
+
253
+ def _load_raw_labels(self):
254
+ fname = 'labels.json'
255
+ if fname not in self._all_fnames:
256
+ return None
257
+ with self._open_file(fname) as f:
258
+ labels = json.load(f)['labels']
259
+ if labels is None:
260
+ return None
261
+ labels = dict(labels)
262
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
263
+ labels = np.array(labels)
264
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
265
+ return labels
266
+
267
+ def __getitem__(self, idx):
268
+ image = self._load_raw_image(self._raw_idx[idx])
269
+
270
+ # for grayscale image
271
+ if image.shape[0] == 1:
272
+ image = np.repeat(image, 3, axis=0)
273
+
274
+ assert isinstance(image, np.ndarray)
275
+ assert list(image.shape) == self.image_shape
276
+ assert image.dtype == np.uint8
277
+ if self._xflip[idx]:
278
+ assert image.ndim == 3 # CHW
279
+ image = image[:, :, ::-1]
280
+ # mask = RandomMask(image.shape[-1], hole_range=self._hole_range) # hole as 0, reserved as 1
281
+ mask = cv2.imread(self.masks[idx], cv2.IMREAD_GRAYSCALE).astype(np.float32)[np.newaxis, :, :] / 255.0
282
+ return image.copy(), mask, self.get_label(idx)
datasets/mask_generator_256.py ADDED
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ from PIL import Image, ImageDraw
3
+ import math
4
+ import random
5
+
6
+
7
+ def RandomBrush(
8
+ max_tries,
9
+ s,
10
+ min_num_vertex = 4,
11
+ max_num_vertex = 18,
12
+ mean_angle = 2*math.pi / 5,
13
+ angle_range = 2*math.pi / 15,
14
+ min_width = 12,
15
+ max_width = 48):
16
+ H, W = s, s
17
+ average_radius = math.sqrt(H*H+W*W) / 8
18
+ mask = Image.new('L', (W, H), 0)
19
+ for _ in range(np.random.randint(max_tries)):
20
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
21
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
22
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
23
+ angles = []
24
+ vertex = []
25
+ for i in range(num_vertex):
26
+ if i % 2 == 0:
27
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
28
+ else:
29
+ angles.append(np.random.uniform(angle_min, angle_max))
30
+
31
+ h, w = mask.size
32
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
33
+ for i in range(num_vertex):
34
+ r = np.clip(
35
+ np.random.normal(loc=average_radius, scale=average_radius//2),
36
+ 0, 2*average_radius)
37
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
38
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
39
+ vertex.append((int(new_x), int(new_y)))
40
+
41
+ draw = ImageDraw.Draw(mask)
42
+ width = int(np.random.uniform(min_width, max_width))
43
+ draw.line(vertex, fill=1, width=width)
44
+ for v in vertex:
45
+ draw.ellipse((v[0] - width//2,
46
+ v[1] - width//2,
47
+ v[0] + width//2,
48
+ v[1] + width//2),
49
+ fill=1)
50
+ if np.random.random() > 0.5:
51
+ mask.transpose(Image.FLIP_LEFT_RIGHT)  # no-op: PIL transpose() returns a new image and the result is discarded; the random flips are applied with np.flip below
52
+ if np.random.random() > 0.5:
53
+ mask.transpose(Image.FLIP_TOP_BOTTOM)  # no-op, as above
54
+ mask = np.asarray(mask, np.uint8)
55
+ if np.random.random() > 0.5:
56
+ mask = np.flip(mask, 0)
57
+ if np.random.random() > 0.5:
58
+ mask = np.flip(mask, 1)
59
+ return mask
60
+
61
+ def RandomMask(s, hole_range=[0,1]):
62
+ coef = min(hole_range[0] + hole_range[1], 1.0)
63
+ while True:
64
+ mask = np.ones((s, s), np.uint8)
65
+ def Fill(max_size):
66
+ w, h = np.random.randint(max_size), np.random.randint(max_size)
67
+ ww, hh = w // 2, h // 2
68
+ x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
69
+ mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
70
+ def MultiFill(max_tries, max_size):
71
+ for _ in range(np.random.randint(max_tries)):
72
+ Fill(max_size)
73
+ MultiFill(int(4 * coef), s // 2)
74
+ MultiFill(int(2 * coef), s)
75
+ mask = np.logical_and(mask, 1 - RandomBrush(int(8 * coef), s)) # hole denoted as 0, reserved as 1
76
+ hole_ratio = 1 - np.mean(mask)
77
+ if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
78
+ continue
79
+ return mask[np.newaxis, ...].astype(np.float32)
80
+
81
+ def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
82
+ return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis=0)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ # res = 512
87
+ res = 256
88
+ cnt = 2000
89
+ tot = 0
90
+ for i in range(cnt):
91
+ mask = RandomMask(s=res)
92
+ tot += mask.mean()
93
+ print(tot / cnt)
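A quick sketch of drawing masks from the generator above and checking their hole ratios (illustrative only, not part of this commit); the batch size and hole range are arbitrary.

```python
# Illustrative only -- not part of this commit.
import numpy as np
from datasets.mask_generator_256 import BatchRandomMask

masks = BatchRandomMask(batch_size=4, s=256, hole_range=[0.1, 0.5])  # (4, 1, 256, 256), float32
hole_ratios = 1.0 - masks.mean(axis=(1, 2, 3))                        # holes are 0, kept pixels are 1
print(hole_ratios)                                                    # each value falls inside (0.1, 0.5)
```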
datasets/mask_generator_256_small.py ADDED
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ from PIL import Image, ImageDraw
3
+ import math
4
+ import random
5
+
6
+
7
+ def RandomBrush(
8
+ max_tries,
9
+ s,
10
+ min_num_vertex = 4,
11
+ max_num_vertex = 18,
12
+ mean_angle = 2*math.pi / 5,
13
+ angle_range = 2*math.pi / 15,
14
+ min_width = 12,
15
+ max_width = 48):
16
+ H, W = s, s
17
+ average_radius = math.sqrt(H*H+W*W) / 8
18
+ mask = Image.new('L', (W, H), 0)
19
+ for _ in range(np.random.randint(max_tries)):
20
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
21
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
22
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
23
+ angles = []
24
+ vertex = []
25
+ for i in range(num_vertex):
26
+ if i % 2 == 0:
27
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
28
+ else:
29
+ angles.append(np.random.uniform(angle_min, angle_max))
30
+
31
+ h, w = mask.size
32
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
33
+ for i in range(num_vertex):
34
+ r = np.clip(
35
+ np.random.normal(loc=average_radius, scale=average_radius//2),
36
+ 0, 2*average_radius)
37
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
38
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
39
+ vertex.append((int(new_x), int(new_y)))
40
+
41
+ draw = ImageDraw.Draw(mask)
42
+ width = int(np.random.uniform(min_width, max_width))
43
+ draw.line(vertex, fill=1, width=width)
44
+ for v in vertex:
45
+ draw.ellipse((v[0] - width//2,
46
+ v[1] - width//2,
47
+ v[0] + width//2,
48
+ v[1] + width//2),
49
+ fill=1)
50
+ if np.random.random() > 0.5:
51
+ mask.transpose(Image.FLIP_LEFT_RIGHT)  # no-op: transpose() result discarded; the flips happen via np.flip below
52
+ if np.random.random() > 0.5:
53
+ mask.transpose(Image.FLIP_TOP_BOTTOM)  # no-op, as above
54
+ mask = np.asarray(mask, np.uint8)
55
+ if np.random.random() > 0.5:
56
+ mask = np.flip(mask, 0)
57
+ if np.random.random() > 0.5:
58
+ mask = np.flip(mask, 1)
59
+ return mask
60
+
61
+ def RandomMask(s, hole_range=[0,1]):
62
+ coef = min(hole_range[0] + hole_range[1], 1.0)
63
+ while True:
64
+ mask = np.ones((s, s), np.uint8)
65
+ def Fill(max_size):
66
+ w, h = np.random.randint(max_size), np.random.randint(max_size)
67
+ ww, hh = w // 2, h // 2
68
+ x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
69
+ mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
70
+ def MultiFill(max_tries, max_size):
71
+ for _ in range(np.random.randint(max_tries)):
72
+ Fill(max_size)
73
+ MultiFill(int(2 * coef), s // 2)
74
+ MultiFill(int(2 * coef), s)
75
+ mask = np.logical_and(mask, 1 - RandomBrush(int(3 * coef), s)) # hole denoted as 0, reserved as 1
76
+ hole_ratio = 1 - np.mean(mask)
77
+ if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
78
+ continue
79
+ return mask[np.newaxis, ...].astype(np.float32)
80
+
81
+ def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
82
+ return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis=0)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ # res = 512
87
+ res = 256
88
+ cnt = 2000
89
+ tot = 0
90
+ for i in range(cnt):
91
+ mask = RandomMask(s=res)
92
+ tot += mask.mean()
93
+ print(tot / cnt)
datasets/mask_generator_512.py ADDED
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ from PIL import Image, ImageDraw
3
+ import math
4
+ import random
5
+
6
+
7
+ def RandomBrush(
8
+ max_tries,
9
+ s,
10
+ min_num_vertex = 4,
11
+ max_num_vertex = 18,
12
+ mean_angle = 2*math.pi / 5,
13
+ angle_range = 2*math.pi / 15,
14
+ min_width = 12,
15
+ max_width = 48):
16
+ H, W = s, s
17
+ average_radius = math.sqrt(H*H+W*W) / 8
18
+ mask = Image.new('L', (W, H), 0)
19
+ for _ in range(np.random.randint(max_tries)):
20
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
21
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
22
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
23
+ angles = []
24
+ vertex = []
25
+ for i in range(num_vertex):
26
+ if i % 2 == 0:
27
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
28
+ else:
29
+ angles.append(np.random.uniform(angle_min, angle_max))
30
+
31
+ h, w = mask.size
32
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
33
+ for i in range(num_vertex):
34
+ r = np.clip(
35
+ np.random.normal(loc=average_radius, scale=average_radius//2),
36
+ 0, 2*average_radius)
37
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
38
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
39
+ vertex.append((int(new_x), int(new_y)))
40
+
41
+ draw = ImageDraw.Draw(mask)
42
+ width = int(np.random.uniform(min_width, max_width))
43
+ draw.line(vertex, fill=1, width=width)
44
+ for v in vertex:
45
+ draw.ellipse((v[0] - width//2,
46
+ v[1] - width//2,
47
+ v[0] + width//2,
48
+ v[1] + width//2),
49
+ fill=1)
50
+ if np.random.random() > 0.5:
51
+ mask.transpose(Image.FLIP_LEFT_RIGHT)  # no-op: transpose() result discarded; the flips happen via np.flip below
52
+ if np.random.random() > 0.5:
53
+ mask.transpose(Image.FLIP_TOP_BOTTOM)  # no-op, as above
54
+ mask = np.asarray(mask, np.uint8)
55
+ if np.random.random() > 0.5:
56
+ mask = np.flip(mask, 0)
57
+ if np.random.random() > 0.5:
58
+ mask = np.flip(mask, 1)
59
+ return mask
60
+
61
+ def RandomMask(s, hole_range=[0,1]):
62
+ coef = min(hole_range[0] + hole_range[1], 1.0)
63
+ while True:
64
+ mask = np.ones((s, s), np.uint8)
65
+ def Fill(max_size):
66
+ w, h = np.random.randint(max_size), np.random.randint(max_size)
67
+ ww, hh = w // 2, h // 2
68
+ x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
69
+ mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
70
+ def MultiFill(max_tries, max_size):
71
+ for _ in range(np.random.randint(max_tries)):
72
+ Fill(max_size)
73
+ MultiFill(int(5 * coef), s // 2)
74
+ MultiFill(int(3 * coef), s)
75
+ mask = np.logical_and(mask, 1 - RandomBrush(int(9 * coef), s)) # hole denoted as 0, reserved as 1
76
+ hole_ratio = 1 - np.mean(mask)
77
+ if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
78
+ continue
79
+ return mask[np.newaxis, ...].astype(np.float32)
80
+
81
+ def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
82
+ return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis=0)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ res = 512
87
+ # res = 256
88
+ cnt = 2000
89
+ tot = 0
90
+ for i in range(cnt):
91
+ mask = RandomMask(s=res)
92
+ tot += mask.mean()
93
+ print(tot / cnt)
datasets/mask_generator_512_small.py ADDED
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ from PIL import Image, ImageDraw
3
+ import math
4
+ import random
5
+
6
+
7
+ def RandomBrush(
8
+ max_tries,
9
+ s,
10
+ min_num_vertex = 4,
11
+ max_num_vertex = 18,
12
+ mean_angle = 2*math.pi / 5,
13
+ angle_range = 2*math.pi / 15,
14
+ min_width = 12,
15
+ max_width = 48):
16
+ H, W = s, s
17
+ average_radius = math.sqrt(H*H+W*W) / 8
18
+ mask = Image.new('L', (W, H), 0)
19
+ for _ in range(np.random.randint(max_tries)):
20
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
21
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
22
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
23
+ angles = []
24
+ vertex = []
25
+ for i in range(num_vertex):
26
+ if i % 2 == 0:
27
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
28
+ else:
29
+ angles.append(np.random.uniform(angle_min, angle_max))
30
+
31
+ h, w = mask.size
32
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
33
+ for i in range(num_vertex):
34
+ r = np.clip(
35
+ np.random.normal(loc=average_radius, scale=average_radius//2),
36
+ 0, 2*average_radius)
37
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
38
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
39
+ vertex.append((int(new_x), int(new_y)))
40
+
41
+ draw = ImageDraw.Draw(mask)
42
+ width = int(np.random.uniform(min_width, max_width))
43
+ draw.line(vertex, fill=1, width=width)
44
+ for v in vertex:
45
+ draw.ellipse((v[0] - width//2,
46
+ v[1] - width//2,
47
+ v[0] + width//2,
48
+ v[1] + width//2),
49
+ fill=1)
50
+ if np.random.random() > 0.5:
51
+ mask.transpose(Image.FLIP_LEFT_RIGHT)  # no-op: transpose() result discarded; the flips happen via np.flip below
52
+ if np.random.random() > 0.5:
53
+ mask.transpose(Image.FLIP_TOP_BOTTOM)  # no-op, as above
54
+ mask = np.asarray(mask, np.uint8)
55
+ if np.random.random() > 0.5:
56
+ mask = np.flip(mask, 0)
57
+ if np.random.random() > 0.5:
58
+ mask = np.flip(mask, 1)
59
+ return mask
60
+
61
+ def RandomMask(s, hole_range=[0,1]):
62
+ coef = min(hole_range[0] + hole_range[1], 1.0)
63
+ while True:
64
+ mask = np.ones((s, s), np.uint8)
65
+ def Fill(max_size):
66
+ w, h = np.random.randint(max_size), np.random.randint(max_size)
67
+ ww, hh = w // 2, h // 2
68
+ x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
69
+ mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
70
+ def MultiFill(max_tries, max_size):
71
+ for _ in range(np.random.randint(max_tries)):
72
+ Fill(max_size)
73
+ MultiFill(int(3 * coef), s // 2)
74
+ MultiFill(int(2 * coef), s)
75
+ mask = np.logical_and(mask, 1 - RandomBrush(int(4 * coef), s)) # hole denoted as 0, reserved as 1
76
+ hole_ratio = 1 - np.mean(mask)
77
+ if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
78
+ continue
79
+ return mask[np.newaxis, ...].astype(np.float32)
80
+
81
+ def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
82
+ return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis=0)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ res = 512
87
+ # res = 256
88
+ cnt = 2000
89
+ tot = 0
90
+ for i in range(cnt):
91
+ mask = RandomMask(s=res)
92
+ tot += mask.mean()
93
+ print(tot / cnt)
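The `_small` variants differ from the regular generators only in the number of box fills and brush strokes, so they should produce sparser masks on average. A rough way to see this (illustrative only, not part of this commit; the sample count is arbitrary and gives only a coarse estimate):

```python
# Illustrative only -- not part of this commit; 50 samples gives only a rough estimate.
import numpy as np
from datasets.mask_generator_512 import RandomMask as RegularMask
from datasets.mask_generator_512_small import RandomMask as SmallMask

n = 50
regular = np.mean([1.0 - RegularMask(512).mean() for _ in range(n)])
small = np.mean([1.0 - SmallMask(512).mean() for _ in range(n)])
print(f'average hole ratio: regular={regular:.3f}, small={small:.3f}')
```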
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
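The package only re-exports two helpers from `dnnlib/util.py` (shown next). A brief sketch of what they do (illustrative only, not part of this commit); the printed cache path depends on the environment.

```python
# Illustrative only -- not part of this commit.
import dnnlib

cfg = dnnlib.EasyDict(lr=0.002, batch=32)        # a dict with attribute-style access
cfg.lr = 0.001
print(cfg['lr'], cfg.batch)                       # 0.001 32

print(dnnlib.make_cache_dir_path('downloads'))    # e.g. ~/.cache/dnnlib/downloads
```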
dnnlib/util.py ADDED
@@ -0,0 +1,477 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def ask_yes_no(question: str) -> bool:
154
+ """Ask the user the question until the user inputs a valid answer."""
155
+ while True:
156
+ try:
157
+ print("{0} [y/n]".format(question))
158
+ return strtobool(input().lower())
159
+ except ValueError:
160
+ pass
161
+
162
+
163
+ def tuple_product(t: Tuple) -> Any:
164
+ """Calculate the product of the tuple elements."""
165
+ result = 1
166
+
167
+ for v in t:
168
+ result *= v
169
+
170
+ return result
171
+
172
+
173
+ _str_to_ctype = {
174
+ "uint8": ctypes.c_ubyte,
175
+ "uint16": ctypes.c_uint16,
176
+ "uint32": ctypes.c_uint32,
177
+ "uint64": ctypes.c_uint64,
178
+ "int8": ctypes.c_byte,
179
+ "int16": ctypes.c_int16,
180
+ "int32": ctypes.c_int32,
181
+ "int64": ctypes.c_int64,
182
+ "float32": ctypes.c_float,
183
+ "float64": ctypes.c_double
184
+ }
185
+
186
+
187
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
188
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
189
+ type_str = None
190
+
191
+ if isinstance(type_obj, str):
192
+ type_str = type_obj
193
+ elif hasattr(type_obj, "__name__"):
194
+ type_str = type_obj.__name__
195
+ elif hasattr(type_obj, "name"):
196
+ type_str = type_obj.name
197
+ else:
198
+ raise RuntimeError("Cannot infer type name from input")
199
+
200
+ assert type_str in _str_to_ctype.keys()
201
+
202
+ my_dtype = np.dtype(type_str)
203
+ my_ctype = _str_to_ctype[type_str]
204
+
205
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
206
+
207
+ return my_dtype, my_ctype
208
+
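A quick illustration of the helper above:

import ctypes
import numpy as np
from dnnlib.util import get_dtype_and_ctype

dtype, ctype = get_dtype_and_ctype('float32')
assert dtype == np.dtype('float32') and ctype is ctypes.c_float
assert dtype.itemsize == ctypes.sizeof(ctype)   # byte sizes are checked to match

dtype, ctype = get_dtype_and_ctype(np.int16)    # anything with a __name__ attribute works too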
209
+
210
+ def is_pickleable(obj: Any) -> bool:
211
+ try:
212
+ with io.BytesIO() as stream:
213
+ pickle.dump(obj, stream)
214
+ return True
215
+ except:
216
+ return False
217
+
218
+
219
+ # Functionality to import modules/objects by name, and call functions by name
220
+ # ------------------------------------------------------------------------------------------
221
+
222
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
223
+ """Searches for the underlying module behind the name to some python object.
224
+ Returns the module and the object name (original name with module part removed)."""
225
+
226
+ # allow convenience shorthands, substitute them by full names
227
+ obj_name = re.sub("^np.", "numpy.", obj_name)
228
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
229
+
230
+ # list alternatives for (module_name, local_obj_name)
231
+ parts = obj_name.split(".")
232
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
233
+
234
+ # try each alternative in turn
235
+ for module_name, local_obj_name in name_pairs:
236
+ try:
237
+ module = importlib.import_module(module_name) # may raise ImportError
238
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
239
+ return module, local_obj_name
240
+ except:
241
+ pass
242
+
243
+ # maybe some of the modules themselves contain errors?
244
+ for module_name, _local_obj_name in name_pairs:
245
+ try:
246
+ importlib.import_module(module_name) # may raise ImportError
247
+ except ImportError:
248
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
249
+ raise
250
+
251
+ # maybe the requested attribute is missing?
252
+ for module_name, local_obj_name in name_pairs:
253
+ try:
254
+ module = importlib.import_module(module_name) # may raise ImportError
255
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
256
+ except ImportError:
257
+ pass
258
+
259
+ # we are out of luck, but we have no idea why
260
+ raise ImportError(obj_name)
261
+
262
+
263
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
264
+ """Traverses the object name and returns the last (rightmost) python object."""
265
+ if obj_name == '':
266
+ return module
267
+ obj = module
268
+ for part in obj_name.split("."):
269
+ obj = getattr(obj, part)
270
+ return obj
271
+
272
+
273
+ def get_obj_by_name(name: str) -> Any:
274
+ """Finds the python object with the given name."""
275
+ module, obj_name = get_module_from_obj_name(name)
276
+ return get_obj_from_module(module, obj_name)
277
+
278
+
279
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
280
+ """Finds the python object with the given name and calls it as a function."""
281
+ assert func_name is not None
282
+ func_obj = get_obj_by_name(func_name)
283
+ assert callable(func_obj)
284
+ return func_obj(*args, **kwargs)
285
+
286
+
287
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
288
+ """Finds the python class with the given name and constructs it with the given arguments."""
289
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
290
+
291
+
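These helpers are what let configs elsewhere in the repo reference functions and classes by dotted string. A minimal sketch, with arbitrary target names chosen purely for illustration:

from dnnlib.util import get_obj_by_name, call_func_by_name, construct_class_by_name

mean_fn = get_obj_by_name('np.mean')                                    # 'np.' expands to 'numpy.'
print(call_func_by_name([1, 2, 3], func_name='numpy.mean'))             # 2.0
frac = construct_class_by_name(1, 3, class_name='fractions.Fraction')  # Fraction(1, 3)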
292
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
293
+ """Get the directory path of the module containing the given object name."""
294
+ module, _ = get_module_from_obj_name(obj_name)
295
+ return os.path.dirname(inspect.getfile(module))
296
+
297
+
298
+ def is_top_level_function(obj: Any) -> bool:
299
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
300
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
301
+
302
+
303
+ def get_top_level_function_name(obj: Any) -> str:
304
+ """Return the fully-qualified name of a top-level function."""
305
+ assert is_top_level_function(obj)
306
+ module = obj.__module__
307
+ if module == '__main__':
308
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
309
+ return module + "." + obj.__name__
310
+
311
+
312
+ # File system helpers
313
+ # ------------------------------------------------------------------------------------------
314
+
315
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
316
+ """List all files recursively in a given directory while ignoring given file and directory names.
317
+ Returns list of tuples containing both absolute and relative paths."""
318
+ assert os.path.isdir(dir_path)
319
+ base_name = os.path.basename(os.path.normpath(dir_path))
320
+
321
+ if ignores is None:
322
+ ignores = []
323
+
324
+ result = []
325
+
326
+ for root, dirs, files in os.walk(dir_path, topdown=True):
327
+ for ignore_ in ignores:
328
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
329
+
330
+ # dirs need to be edited in-place
331
+ for d in dirs_to_remove:
332
+ dirs.remove(d)
333
+
334
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
335
+
336
+ absolute_paths = [os.path.join(root, f) for f in files]
337
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
338
+
339
+ if add_base_to_relative:
340
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
341
+
342
+ assert len(absolute_paths) == len(relative_paths)
343
+ result += zip(absolute_paths, relative_paths)
344
+
345
+ return result
346
+
347
+
348
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
349
+ """Takes in a list of tuples of (src, dst) paths and copies files.
350
+ Will create all necessary directories."""
351
+ for file in files:
352
+ target_dir_name = os.path.dirname(file[1])
353
+
354
+ # will create all intermediate-level directories
355
+ if not os.path.exists(target_dir_name):
356
+ os.makedirs(target_dir_name)
357
+
358
+ shutil.copyfile(file[0], file[1])
359
+
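A usage sketch combining the two file-system helpers above (all paths are illustrative):

import os
from dnnlib.util import list_dir_recursively_with_ignore, copy_files_and_create_dirs

# Mirror a source tree while skipping caches and pickles.
files = list_dir_recursively_with_ignore('src_dir', ignores=['__pycache__', '*.pkl'],
                                         add_base_to_relative=True)
# Each entry is (absolute_path, relative_path), e.g. ('/abs/src_dir/train.py', 'src_dir/train.py').
copy_files_and_create_dirs([(src, os.path.join('backup', rel)) for src, rel in files])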
360
+
361
+ # URL helpers
362
+ # ------------------------------------------------------------------------------------------
363
+
364
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
365
+ """Determine whether the given object is a valid URL string."""
366
+ if not isinstance(obj, str) or not "://" in obj:
367
+ return False
368
+ if allow_file_urls and obj.startswith('file://'):
369
+ return True
370
+ try:
371
+ res = requests.compat.urlparse(obj)
372
+ if not res.scheme or not res.netloc or not "." in res.netloc:
373
+ return False
374
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
375
+ if not res.scheme or not res.netloc or not "." in res.netloc:
376
+ return False
377
+ except:
378
+ return False
379
+ return True
380
+
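A few examples of what is_url above accepts (illustrative):

from dnnlib.util import is_url

is_url('https://example.com/model.pkl')                   # True
is_url('file:///tmp/model.pkl')                           # False
is_url('file:///tmp/model.pkl', allow_file_urls=True)     # True
is_url('/local/path/model.pkl')                           # False; open_url treats it as a plain filename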
381
+
382
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
383
+ """Download the given URL and return a binary-mode file object to access the data."""
384
+ assert num_attempts >= 1
385
+ assert not (return_filename and (not cache))
386
+
387
+ # Doesn't look like a URL scheme, so interpret it as a local filename.
388
+ if not re.match('^[a-z]+://', url):
389
+ return url if return_filename else open(url, "rb")
390
+
391
+ # Handle file URLs. This code handles unusual file:// patterns that
392
+ # arise on Windows:
393
+ #
394
+ # file:///c:/foo.txt
395
+ #
396
+ # which would translate to a local '/c:/foo.txt' filename that's
397
+ # invalid. Drop the forward slash for such pathnames.
398
+ #
399
+ # If you touch this code path, you should test it on both Linux and
400
+ # Windows.
401
+ #
402
+ # Some internet resources suggest using urllib.request.url2pathname() but
403
+ # that converts forward slashes to backslashes, and this causes
404
+ # its own set of problems.
405
+ if url.startswith('file://'):
406
+ filename = urllib.parse.urlparse(url).path
407
+ if re.match(r'^/[a-zA-Z]:', filename):
408
+ filename = filename[1:]
409
+ return filename if return_filename else open(filename, "rb")
410
+
411
+ assert is_url(url)
412
+
413
+ # Lookup from cache.
414
+ if cache_dir is None:
415
+ cache_dir = make_cache_dir_path('downloads')
416
+
417
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
418
+ if cache:
419
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
420
+ if len(cache_files) == 1:
421
+ filename = cache_files[0]
422
+ return filename if return_filename else open(filename, "rb")
423
+
424
+ # Download.
425
+ url_name = None
426
+ url_data = None
427
+ with requests.Session() as session:
428
+ if verbose:
429
+ print("Downloading %s ..." % url, end="", flush=True)
430
+ for attempts_left in reversed(range(num_attempts)):
431
+ try:
432
+ with session.get(url) as res:
433
+ res.raise_for_status()
434
+ if len(res.content) == 0:
435
+ raise IOError("No data received")
436
+
437
+ if len(res.content) < 8192:
438
+ content_str = res.content.decode("utf-8")
439
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
440
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
441
+ if len(links) == 1:
442
+ url = requests.compat.urljoin(url, links[0])
443
+ raise IOError("Google Drive virus checker nag")
444
+ if "Google Drive - Quota exceeded" in content_str:
445
+ raise IOError("Google Drive download quota exceeded -- please try again later")
446
+
447
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
448
+ url_name = match[1] if match else url
449
+ url_data = res.content
450
+ if verbose:
451
+ print(" done")
452
+ break
453
+ except KeyboardInterrupt:
454
+ raise
455
+ except:
456
+ if not attempts_left:
457
+ if verbose:
458
+ print(" failed")
459
+ raise
460
+ if verbose:
461
+ print(".", end="", flush=True)
462
+
463
+ # Save to cache.
464
+ if cache:
465
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
466
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
467
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
468
+ os.makedirs(cache_dir, exist_ok=True)
469
+ with open(temp_file, "wb") as f:
470
+ f.write(url_data)
471
+ os.replace(temp_file, cache_file) # atomic
472
+ if return_filename:
473
+ return cache_file
474
+
475
+ # Return data as file object.
476
+ assert not return_filename
477
+ return io.BytesIO(url_data)
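A typical call pattern for open_url, as used by the metric and legacy loaders later in this diff (the URL shown is the Inception checkpoint referenced in the evaluation script below; the local path is illustrative):

import dnnlib

# Remote files are downloaded once, cached under make_cache_dir_path('downloads'),
# and re-read from the cache on subsequent calls.
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
with dnnlib.util.open_url(url) as f:
    blob = f.read()

# Plain local paths and file:// URLs skip the download logic entirely.
with dnnlib.util.open_url('checkpoints/example.pkl') as f:
    blob = f.read()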
evaluatoin/cal_fid_pids_uids.py ADDED
@@ -0,0 +1,193 @@
1
+ import cv2
2
+ import os
3
+ import sys
4
+ sys.path.insert(0, '../')
5
+ import numpy as np
6
+ import math
7
+ import glob
8
+ import pyspng
9
+ import PIL.Image
10
+ import torch
11
+ import dnnlib
12
+ import scipy.linalg
13
+ import sklearn.svm
+ import pickle  # used by FeatureStats.save / FeatureStats.load below
14
+
15
+
16
+ _feature_detector_cache = dict()
17
+
18
+ def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
19
+ assert 0 <= rank < num_gpus
20
+ key = (url, device)
21
+ if key not in _feature_detector_cache:
22
+ is_leader = (rank == 0)
23
+ if not is_leader and num_gpus > 1:
24
+ torch.distributed.barrier() # leader goes first
25
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
26
+ _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
27
+ if is_leader and num_gpus > 1:
28
+ torch.distributed.barrier() # others follow
29
+ return _feature_detector_cache[key]
30
+
31
+
32
+ def read_image(image_path):
33
+ with open(image_path, 'rb') as f:
34
+ if pyspng is not None and image_path.endswith('.png'):
35
+ image = pyspng.load(f.read())
36
+ else:
37
+ image = np.array(PIL.Image.open(f))
38
+ if image.ndim == 2:
39
+ image = image[:, :, np.newaxis] # HW => HWC
40
+ if image.shape[2] == 1:
41
+ image = np.repeat(image, 3, axis=2)
42
+ image = image.transpose(2, 0, 1) # HWC => CHW
43
+ image = torch.from_numpy(image).unsqueeze(0).to(torch.uint8)
44
+
45
+ return image
46
+
47
+
48
+ class FeatureStats:
49
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
50
+ self.capture_all = capture_all
51
+ self.capture_mean_cov = capture_mean_cov
52
+ self.max_items = max_items
53
+ self.num_items = 0
54
+ self.num_features = None
55
+ self.all_features = None
56
+ self.raw_mean = None
57
+ self.raw_cov = None
58
+
59
+ def set_num_features(self, num_features):
60
+ if self.num_features is not None:
61
+ assert num_features == self.num_features
62
+ else:
63
+ self.num_features = num_features
64
+ self.all_features = []
65
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
66
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
67
+
68
+ def is_full(self):
69
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
70
+
71
+ def append(self, x):
72
+ x = np.asarray(x, dtype=np.float32)
73
+ assert x.ndim == 2
74
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
75
+ if self.num_items >= self.max_items:
76
+ return
77
+ x = x[:self.max_items - self.num_items]
78
+
79
+ self.set_num_features(x.shape[1])
80
+ self.num_items += x.shape[0]
81
+ if self.capture_all:
82
+ self.all_features.append(x)
83
+ if self.capture_mean_cov:
84
+ x64 = x.astype(np.float64)
85
+ self.raw_mean += x64.sum(axis=0)
86
+ self.raw_cov += x64.T @ x64
87
+
88
+ def append_torch(self, x, num_gpus=1, rank=0):
89
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
90
+ assert 0 <= rank < num_gpus
91
+ if num_gpus > 1:
92
+ ys = []
93
+ for src in range(num_gpus):
94
+ y = x.clone()
95
+ torch.distributed.broadcast(y, src=src)
96
+ ys.append(y)
97
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
98
+ self.append(x.cpu().numpy())
99
+
100
+ def get_all(self):
101
+ assert self.capture_all
102
+ return np.concatenate(self.all_features, axis=0)
103
+
104
+ def get_all_torch(self):
105
+ return torch.from_numpy(self.get_all())
106
+
107
+ def get_mean_cov(self):
108
+ assert self.capture_mean_cov
109
+ mean = self.raw_mean / self.num_items
110
+ cov = self.raw_cov / self.num_items
111
+ cov = cov - np.outer(mean, mean)
112
+ return mean, cov
113
+
114
+ def save(self, pkl_file):
115
+ with open(pkl_file, 'wb') as f:
116
+ pickle.dump(self.__dict__, f)
117
+
118
+ @staticmethod
119
+ def load(pkl_file):
120
+ with open(pkl_file, 'rb') as f:
121
+ s = dnnlib.EasyDict(pickle.load(f))
122
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
123
+ obj.__dict__.update(s)
124
+ return obj
125
+
126
+
127
+ def calculate_metrics(folder1, folder2):
128
+ l1 = sorted(glob.glob(folder1 + '/*.png') + glob.glob(folder1 + '/*.jpg'))
129
+ l2 = sorted(glob.glob(folder2 + '/*.png') + glob.glob(folder2 + '/*.jpg'))
130
+ assert(len(l1) == len(l2))
131
+ print('length:', len(l1))
132
+
133
+ # l1 = l1[:3]; l2 = l2[:3];
134
+
135
+ # build detector
136
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
137
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
138
+ device = torch.device('cuda:0')
139
+ detector = get_feature_detector(url=detector_url, device=device, num_gpus=1, rank=0, verbose=False)
140
+ detector.eval()
141
+
142
+ stat1 = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=len(l1))
143
+ stat2 = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=len(l1))
144
+
145
+ with torch.no_grad():
146
+ for i, (fpath1, fpath2) in enumerate(zip(l1, l2)):
147
+ print(i)
148
+ _, name1 = os.path.split(fpath1)
149
+ _, name2 = os.path.split(fpath2)
150
+ name1 = name1.split('.')[0]
151
+ name2 = name2.split('.')[0]
152
+ assert name1 == name2, 'Illegal mapping: %s, %s' % (name1, name2)
153
+
154
+ img1 = read_image(fpath1).to(device)
155
+ img2 = read_image(fpath2).to(device)
156
+ assert img1.shape == img2.shape, 'Illegal shape'
157
+ fea1 = detector(img1, **detector_kwargs)
158
+ stat1.append_torch(fea1, num_gpus=1, rank=0)
159
+ fea2 = detector(img2, **detector_kwargs)
160
+ stat2.append_torch(fea2, num_gpus=1, rank=0)
161
+
162
+ # calculate fid
163
+ mu1, sigma1 = stat1.get_mean_cov()
164
+ mu2, sigma2 = stat2.get_mean_cov()
165
+ m = np.square(mu1 - mu2).sum()
166
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False) # pylint: disable=no-member
167
+ fid = np.real(m + np.trace(sigma1 + sigma2 - s * 2))
168
+
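The three lines above compute the Fréchet inception distance between the two feature Gaussians, FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2)); np.real() drops the numerically spurious imaginary part that scipy.linalg.sqrtm can return.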
169
+ # calculate pids and uids
170
+ fake_activations = stat1.get_all()
171
+ real_activations = stat2.get_all()
172
+ svm = sklearn.svm.LinearSVC(dual=False)
173
+ svm_inputs = np.concatenate([real_activations, fake_activations])
174
+ svm_targets = np.array([1] * real_activations.shape[0] + [0] * fake_activations.shape[0])
175
+ print('SVM fitting ...')
176
+ svm.fit(svm_inputs, svm_targets)
177
+ uids = 1 - svm.score(svm_inputs, svm_targets)
178
+ real_outputs = svm.decision_function(real_activations)
179
+ fake_outputs = svm.decision_function(fake_activations)
180
+ pids = np.mean(fake_outputs > real_outputs)
181
+
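For reference: uids (unpaired inception discriminative score) is the misclassification rate of a linear SVM trained to separate real from inpainted Inception features, and pids (paired score) is the fraction of image pairs for which the inpainted result receives a higher decision score than its ground-truth counterpart; higher is better for both.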
182
+ return fid, pids, uids
183
+
184
+
185
+ if __name__ == '__main__':
186
+ folder1 = 'path to the inpainted result'
187
+ folder2 = 'path to the gt'
188
+
189
+ fid, pids, uids = calculate_metrics(folder1, folder2)
190
+ print('fid: %.4f, pids: %.4f, uids: %.4f' % (fid, pids, uids))
191
+ with open('fid_pids_uids.txt', 'w') as f:
192
+ f.write('fid: %.4f, pids: %.4f, uids: %.4f' % (fid, pids, uids))
193
+
evaluatoin/cal_lpips.py ADDED
@@ -0,0 +1,71 @@
1
+ import cv2
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ import math
6
+ import glob
7
+ import pyspng
8
+ import PIL.Image
9
+
10
+ import torch
11
+ import lpips
12
+
13
+
14
+ def read_image(image_path):
15
+ with open(image_path, 'rb') as f:
16
+ if pyspng is not None and image_path.endswith('.png'):
17
+ image = pyspng.load(f.read())
18
+ else:
19
+ image = np.array(PIL.Image.open(f))
20
+ if image.ndim == 2:
21
+ image = image[:, :, np.newaxis] # HW => HWC
22
+ if image.shape[2] == 1:
23
+ image = np.repeat(image, 3, axis=2)
24
+ image = image.transpose(2, 0, 1) # HWC => CHW
25
+ image = torch.from_numpy(image).float().unsqueeze(0)
26
+ image = image / 127.5 - 1
27
+
28
+ return image
29
+
30
+
31
+ def calculate_metrics(folder1, folder2):
32
+ l1 = sorted(glob.glob(folder1 + '/*.png') + glob.glob(folder1 + '/*.jpg'))
33
+ l2 = sorted(glob.glob(folder2 + '/*.png') + glob.glob(folder2 + '/*.jpg'))
34
+ assert(len(l1) == len(l2))
35
+ print('length:', len(l1))
36
+
37
+ # l1 = l1[:3]; l2 = l2[:3];
38
+
39
+ device = torch.device('cuda:0')
40
+ loss_fn = lpips.LPIPS(net='alex').to(device)
41
+ loss_fn.eval()
42
+ # loss_fn = lpips.LPIPS(net='vgg').to(device)
43
+
44
+ lpips_l = []
45
+ with torch.no_grad():
46
+ for i, (fpath1, fpath2) in enumerate(zip(l1, l2)):
47
+ print(i)
48
+ _, name1 = os.path.split(fpath1)
49
+ _, name2 = os.path.split(fpath2)
50
+ name1 = name1.split('.')[0]
51
+ name2 = name2.split('.')[0]
52
+ assert name1 == name2, 'Illegal mapping: %s, %s' % (name1, name2)
53
+
54
+ img1 = read_image(fpath1).to(device)
55
+ img2 = read_image(fpath2).to(device)
56
+ assert img1.shape == img2.shape, 'Illegal shape'
57
+ lpips_l.append(loss_fn(img1, img2).mean().cpu().numpy())
58
+
59
+ res = sum(lpips_l) / len(lpips_l)
60
+
61
+ return res
62
+
63
+
64
+ if __name__ == '__main__':
65
+ folder1 = 'path to the inpainted result'
66
+ folder2 = 'path to the gt'
67
+
68
+ res = calculate_metrics(folder1, folder2)
69
+ print('lpips: %.4f' % res)
70
+ with open('lpips.txt', 'w') as f:
71
+ f.write('lpips: %.4f' % res)
evaluatoin/cal_psnr_ssim_l1.py ADDED
@@ -0,0 +1,107 @@
1
+ import cv2
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ import math
6
+ import glob
7
+ import pyspng
8
+ import PIL.Image
9
+
10
+
11
+ def calculate_psnr(img1, img2):
12
+ # img1 and img2 have range [0, 255]
13
+ img1 = img1.astype(np.float64)
14
+ img2 = img2.astype(np.float64)
15
+ mse = np.mean((img1 - img2) ** 2)
16
+ if mse == 0:
17
+ return float('inf')
18
+
19
+ return 20 * math.log10(255.0 / math.sqrt(mse))
20
+
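calculate_psnr above implements PSNR = 20 * log10(255 / sqrt(MSE)) for 8-bit images. A quick sanity check (illustrative, assuming the function defined above is in scope):

import numpy as np

a = np.zeros((4, 4), dtype=np.float64)
b = np.full((4, 4), 10.0)            # a constant error of 10 gray levels
print(calculate_psnr(a, b))          # 20 * log10(255 / 10) ≈ 28.13 dB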
21
+
22
+ def calculate_ssim(img1, img2):
23
+ C1 = (0.01 * 255) ** 2
24
+ C2 = (0.03 * 255) ** 2
25
+
26
+ img1 = img1.astype(np.float64)
27
+ img2 = img2.astype(np.float64)
28
+ kernel = cv2.getGaussianKernel(11, 1.5)
29
+ window = np.outer(kernel, kernel.transpose())
30
+
31
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
32
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
33
+ mu1_sq = mu1 ** 2
34
+ mu2_sq = mu2 ** 2
35
+ mu1_mu2 = mu1 * mu2
36
+ sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
37
+ sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
38
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
39
+
40
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
41
+
42
+ return ssim_map.mean()
43
+
44
+
45
+ def calculate_l1(img1, img2):
46
+ img1 = img1.astype(np.float64) / 255.0
47
+ img2 = img2.astype(np.float64) / 255.0
48
+ l1 = np.mean(np.abs(img1 - img2))
49
+
50
+ return l1
51
+
52
+
53
+ def read_image(image_path):
54
+ with open(image_path, 'rb') as f:
55
+ if pyspng is not None and image_path.endswith('.png'):
56
+ image = pyspng.load(f.read())
57
+ else:
58
+ image = np.array(PIL.Image.open(f))
59
+ if image.ndim == 2:
60
+ image = image[:, :, np.newaxis] # HW => HWC
61
+ if image.shape[2] == 1:
62
+ image = np.repeat(image, 3, axis=2)
63
+ # image = image.transpose(2, 0, 1) # HWC => CHW
64
+
65
+ return image
66
+
67
+
68
+ def calculate_metrics(folder1, folder2):
69
+ l1 = sorted(glob.glob(folder1 + '/*.png') + glob.glob(folder1 + '/*.jpg'))
70
+ l2 = sorted(glob.glob(folder2 + '/*.png') + glob.glob(folder2 + '/*.jpg'))
71
+ assert(len(l1) == len(l2))
72
+ print('length:', len(l1))
73
+
74
+ # l1 = l1[:3]; l2 = l2[:3];
75
+
76
+ psnr_l, ssim_l, dl1_l = [], [], []
77
+ for i, (fpath1, fpath2) in enumerate(zip(l1, l2)):
78
+ print(i)
79
+ _, name1 = os.path.split(fpath1)
80
+ _, name2 = os.path.split(fpath2)
81
+ name1 = name1.split('.')[0]
82
+ name2 = name2.split('.')[0]
83
+ assert name1 == name2, 'Illegal mapping: %s, %s' % (name1, name2)
84
+
85
+ img1 = read_image(fpath1).astype(np.float64)
86
+ img2 = read_image(fpath2).astype(np.float64)
87
+ assert img1.shape == img2.shape, 'Illegal shape'
88
+ psnr_l.append(calculate_psnr(img1, img2))
89
+ ssim_l.append(calculate_ssim(img1, img2))
90
+ dl1_l.append(calculate_l1(img1, img2))
91
+
92
+ psnr = sum(psnr_l) / len(psnr_l)
93
+ ssim = sum(ssim_l) / len(ssim_l)
94
+ dl1 = sum(dl1_l) / len(dl1_l)
95
+
96
+ return psnr, ssim, dl1
97
+
98
+
99
+ if __name__ == '__main__':
100
+ folder1 = 'path to the inpainted result'
101
+ folder2 = 'path to the gt'
102
+
103
+ psnr, ssim, dl1 = calculate_metrics(folder1, folder2)
104
+ print('psnr: %.4f, ssim: %.4f, l1: %.4f' % (psnr, ssim, dl1))
105
+ with open('psnr_ssim_l1.txt', 'w') as f:
106
+ f.write('psnr: %.4f, ssim: %.4f, l1: %.4f' % (psnr, ssim, dl1))
107
+
legacy.py ADDED
@@ -0,0 +1,323 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def load_network_pkl(f, force_fp16=False):
21
+ data = _LegacyUnpickler(f).load()
22
+
23
+ # Legacy TensorFlow pickle => convert.
24
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
+ tf_G, tf_D, tf_Gs = data
26
+ G = convert_tf_generator(tf_G)
27
+ D = convert_tf_discriminator(tf_D)
28
+ G_ema = convert_tf_generator(tf_Gs)
29
+ data = dict(G=G, D=D, G_ema=G_ema)
30
+
31
+ # Add missing fields.
32
+ if 'training_set_kwargs' not in data:
33
+ data['training_set_kwargs'] = None
34
+ if 'augment_pipe' not in data:
35
+ data['augment_pipe'] = None
36
+
37
+ # Validate contents.
38
+ assert isinstance(data['G'], torch.nn.Module)
39
+ assert isinstance(data['D'], torch.nn.Module)
40
+ assert isinstance(data['G_ema'], torch.nn.Module)
41
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43
+
44
+ # Force FP16.
45
+ if force_fp16:
46
+ for key in ['G', 'D', 'G_ema']:
47
+ old = data[key]
48
+ kwargs = copy.deepcopy(old.init_kwargs)
49
+ if key.startswith('G'):
50
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51
+ kwargs.synthesis_kwargs.num_fp16_res = 4
52
+ kwargs.synthesis_kwargs.conv_clamp = 256
53
+ if key.startswith('D'):
54
+ kwargs.num_fp16_res = 4
55
+ kwargs.conv_clamp = 256
56
+ if kwargs != old.init_kwargs:
57
+ new = type(old)(**kwargs).eval().requires_grad_(False)
58
+ misc.copy_params_and_buffers(old, new, require_all=True)
59
+ data[key] = new
60
+ return data
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ class _TFNetworkStub(dnnlib.EasyDict):
65
+ pass
66
+
67
+ class _LegacyUnpickler(pickle.Unpickler):
68
+ def find_class(self, module, name):
69
+ if module == 'torch.storage' and name == '_load_from_bytes':
70
+ import io
71
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
72
+ if module == 'dnnlib.tflib.network' and name == 'Network':
73
+ return _TFNetworkStub
74
+ return super().find_class(module, name)
75
+
76
+ #----------------------------------------------------------------------------
77
+
78
+ def _collect_tf_params(tf_net):
79
+ # pylint: disable=protected-access
80
+ tf_params = dict()
81
+ def recurse(prefix, tf_net):
82
+ for name, value in tf_net.variables:
83
+ tf_params[prefix + name] = value
84
+ for name, comp in tf_net.components.items():
85
+ recurse(prefix + name + '/', comp)
86
+ recurse('', tf_net)
87
+ return tf_params
88
+
89
+ #----------------------------------------------------------------------------
90
+
91
+ def _populate_module_params(module, *patterns):
92
+ for name, tensor in misc.named_params_and_buffers(module):
93
+ found = False
94
+ value = None
95
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
96
+ match = re.fullmatch(pattern, name)
97
+ if match:
98
+ found = True
99
+ if value_fn is not None:
100
+ value = value_fn(*match.groups())
101
+ break
102
+ try:
103
+ assert found
104
+ if value is not None:
105
+ tensor.copy_(torch.from_numpy(np.array(value)))
106
+ except:
107
+ print(name, list(tensor.shape))
108
+ raise
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
+ def convert_tf_generator(tf_G):
113
+ if tf_G.version < 4:
114
+ raise ValueError('TensorFlow pickle version too low')
115
+
116
+ # Collect kwargs.
117
+ tf_kwargs = tf_G.static_kwargs
118
+ known_kwargs = set()
119
+ def kwarg(tf_name, default=None, none=None):
120
+ known_kwargs.add(tf_name)
121
+ val = tf_kwargs.get(tf_name, default)
122
+ return val if val is not None else none
123
+
124
+ # Convert kwargs.
125
+ kwargs = dnnlib.EasyDict(
126
+ z_dim = kwarg('latent_size', 512),
127
+ c_dim = kwarg('label_size', 0),
128
+ w_dim = kwarg('dlatent_size', 512),
129
+ img_resolution = kwarg('resolution', 1024),
130
+ img_channels = kwarg('num_channels', 3),
131
+ mapping_kwargs = dnnlib.EasyDict(
132
+ num_layers = kwarg('mapping_layers', 8),
133
+ embed_features = kwarg('label_fmaps', None),
134
+ layer_features = kwarg('mapping_fmaps', None),
135
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
136
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
137
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
138
+ ),
139
+ synthesis_kwargs = dnnlib.EasyDict(
140
+ channel_base = kwarg('fmap_base', 16384) * 2,
141
+ channel_max = kwarg('fmap_max', 512),
142
+ num_fp16_res = kwarg('num_fp16_res', 0),
143
+ conv_clamp = kwarg('conv_clamp', None),
144
+ architecture = kwarg('architecture', 'skip'),
145
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
146
+ use_noise = kwarg('use_noise', True),
147
+ activation = kwarg('nonlinearity', 'lrelu'),
148
+ ),
149
+ )
150
+
151
+ # Check for unknown kwargs.
152
+ kwarg('truncation_psi')
153
+ kwarg('truncation_cutoff')
154
+ kwarg('style_mixing_prob')
155
+ kwarg('structure')
156
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
157
+ if len(unknown_kwargs) > 0:
158
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
159
+
160
+ # Collect params.
161
+ tf_params = _collect_tf_params(tf_G)
162
+ for name, value in list(tf_params.items()):
163
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
164
+ if match:
165
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
166
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
167
+ kwargs.synthesis_kwargs.architecture = 'orig'
168
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
169
+
170
+ # Convert params.
171
+ from training import networks
172
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
173
+ # pylint: disable=unnecessary-lambda
174
+ _populate_module_params(G,
175
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
176
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
177
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
178
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
179
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
180
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
181
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
182
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
183
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
184
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
185
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
186
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
187
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
188
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
189
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
190
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
191
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
192
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
193
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
194
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
195
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
196
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
197
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
198
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
199
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
200
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
201
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
202
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
203
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
204
+ r'.*\.resample_filter', None,
205
+ )
206
+ return G
207
+
208
+ #----------------------------------------------------------------------------
209
+
210
+ def convert_tf_discriminator(tf_D):
211
+ if tf_D.version < 4:
212
+ raise ValueError('TensorFlow pickle version too low')
213
+
214
+ # Collect kwargs.
215
+ tf_kwargs = tf_D.static_kwargs
216
+ known_kwargs = set()
217
+ def kwarg(tf_name, default=None):
218
+ known_kwargs.add(tf_name)
219
+ return tf_kwargs.get(tf_name, default)
220
+
221
+ # Convert kwargs.
222
+ kwargs = dnnlib.EasyDict(
223
+ c_dim = kwarg('label_size', 0),
224
+ img_resolution = kwarg('resolution', 1024),
225
+ img_channels = kwarg('num_channels', 3),
226
+ architecture = kwarg('architecture', 'resnet'),
227
+ channel_base = kwarg('fmap_base', 16384) * 2,
228
+ channel_max = kwarg('fmap_max', 512),
229
+ num_fp16_res = kwarg('num_fp16_res', 0),
230
+ conv_clamp = kwarg('conv_clamp', None),
231
+ cmap_dim = kwarg('mapping_fmaps', None),
232
+ block_kwargs = dnnlib.EasyDict(
233
+ activation = kwarg('nonlinearity', 'lrelu'),
234
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
235
+ freeze_layers = kwarg('freeze_layers', 0),
236
+ ),
237
+ mapping_kwargs = dnnlib.EasyDict(
238
+ num_layers = kwarg('mapping_layers', 0),
239
+ embed_features = kwarg('mapping_fmaps', None),
240
+ layer_features = kwarg('mapping_fmaps', None),
241
+ activation = kwarg('nonlinearity', 'lrelu'),
242
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
243
+ ),
244
+ epilogue_kwargs = dnnlib.EasyDict(
245
+ mbstd_group_size = kwarg('mbstd_group_size', None),
246
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
247
+ activation = kwarg('nonlinearity', 'lrelu'),
248
+ ),
249
+ )
250
+
251
+ # Check for unknown kwargs.
252
+ kwarg('structure')
253
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
254
+ if len(unknown_kwargs) > 0:
255
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
256
+
257
+ # Collect params.
258
+ tf_params = _collect_tf_params(tf_D)
259
+ for name, value in list(tf_params.items()):
260
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
261
+ if match:
262
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
263
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
264
+ kwargs.architecture = 'orig'
265
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
266
+
267
+ # Convert params.
268
+ from training import networks
269
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
270
+ # pylint: disable=unnecessary-lambda
271
+ _populate_module_params(D,
272
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
273
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
274
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
275
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
276
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
277
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
278
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
279
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
280
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
281
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
282
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
283
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
284
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
285
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
286
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
287
+ r'.*\.resample_filter', None,
288
+ )
289
+ return D
290
+
291
+ #----------------------------------------------------------------------------
292
+
293
+ @click.command()
294
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
295
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
296
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
297
+ def convert_network_pickle(source, dest, force_fp16):
298
+ """Convert legacy network pickle into the native PyTorch format.
299
+
300
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
301
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
302
+
303
+ Example:
304
+
305
+ \b
306
+ python legacy.py \\
307
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
308
+ --dest=stylegan2-cat-config-f.pkl
309
+ """
310
+ print(f'Loading "{source}"...')
311
+ with dnnlib.util.open_url(source) as f:
312
+ data = load_network_pkl(f, force_fp16=force_fp16)
313
+ print(f'Saving "{dest}"...')
314
+ with open(dest, 'wb') as f:
315
+ pickle.dump(data, f)
316
+ print('Done.')
317
+
318
+ #----------------------------------------------------------------------------
319
+
320
+ if __name__ == "__main__":
321
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
322
+
323
+ #----------------------------------------------------------------------------
losses/loss.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch_utils import training_stats
12
+ from torch_utils import misc
13
+ from torch_utils.ops import conv2d_gradfix
14
+ from losses.pcp import PerceptualLoss
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ class Loss:
19
+ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
20
+ raise NotImplementedError()
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ class TwoStageLoss(Loss):
25
+ def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, truncation_psi=1, pcp_ratio=1.0):
26
+ super().__init__()
27
+ self.device = device
28
+ self.G_mapping = G_mapping
29
+ self.G_synthesis = G_synthesis
30
+ self.D = D
31
+ self.augment_pipe = augment_pipe
32
+ self.style_mixing_prob = style_mixing_prob
33
+ self.r1_gamma = r1_gamma
34
+ self.pl_batch_shrink = pl_batch_shrink
35
+ self.pl_decay = pl_decay
36
+ self.pl_weight = pl_weight
37
+ self.pl_mean = torch.zeros([], device=device)
38
+ self.truncation_psi = truncation_psi
39
+ self.pcp = PerceptualLoss(layer_weights=dict(conv4_4=1/4, conv5_4=1/2)).to(device)
40
+ self.pcp_ratio = pcp_ratio
41
+
42
+ def run_G(self, img_in, mask_in, z, c, sync):
43
+ with misc.ddp_sync(self.G_mapping, sync):
44
+ ws = self.G_mapping(z, c, truncation_psi=self.truncation_psi)
45
+ if self.style_mixing_prob > 0:
46
+ with torch.autograd.profiler.record_function('style_mixing'):
47
+ cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
48
+ cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
49
+ ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, truncation_psi=self.truncation_psi, skip_w_avg_update=True)[:, cutoff:]
50
+ with misc.ddp_sync(self.G_synthesis, sync):
51
+ img, img_stg1 = self.G_synthesis(img_in, mask_in, ws, return_stg1=True)
52
+ return img, ws, img_stg1
53
+
54
+ def run_D(self, img, mask, img_stg1, c, sync):
55
+ # if self.augment_pipe is not None:
56
+ # # img = self.augment_pipe(img)
57
+ # # !!!!! have to remove the color transform
58
+ # tmp_img = torch.cat([img, mask], dim=1)
59
+ # tmp_img = self.augment_pipe(tmp_img)
60
+ # img, mask = torch.split(tmp_img, [3, 1])
61
+ with misc.ddp_sync(self.D, sync):
62
+ logits, logits_stg1 = self.D(img, mask, img_stg1, c)
63
+ return logits, logits_stg1
64
+
65
+ def accumulate_gradients(self, phase, real_img, mask, real_c, gen_z, gen_c, sync, gain):
66
+ assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
67
+ do_Gmain = (phase in ['Gmain', 'Gboth'])
68
+ do_Dmain = (phase in ['Dmain', 'Dboth'])
69
+ do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
70
+ do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
71
+
72
+ # Gmain: Maximize logits for generated images.
73
+ if do_Gmain:
74
+ with torch.autograd.profiler.record_function('Gmain_forward'):
75
+ gen_img, _gen_ws, gen_img_stg1 = self.run_G(real_img, mask, gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
76
+ gen_logits, gen_logits_stg1 = self.run_D(gen_img, mask, gen_img_stg1, gen_c, sync=False)
77
+ training_stats.report('Loss/scores/fake', gen_logits)
78
+ training_stats.report('Loss/signs/fake', gen_logits.sign())
79
+ training_stats.report('Loss/scores/fake_s1', gen_logits_stg1)
80
+ training_stats.report('Loss/signs/fake_s1', gen_logits_stg1.sign())
81
+ loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
82
+ training_stats.report('Loss/G/loss', loss_Gmain)
83
+ loss_Gmain_stg1 = torch.nn.functional.softplus(-gen_logits_stg1)
84
+ training_stats.report('Loss/G/loss_s1', loss_Gmain_stg1)
85
+ # just for showing
86
+ l1_loss = torch.mean(torch.abs(gen_img - real_img))
87
+ training_stats.report('Loss/G/l1_loss', l1_loss)
88
+ pcp_loss, _ = self.pcp(gen_img, real_img)
89
+ training_stats.report('Loss/G/pcp_loss', pcp_loss)
90
+ with torch.autograd.profiler.record_function('Gmain_backward'):
91
+ loss_Gmain_all = loss_Gmain + loss_Gmain_stg1 + pcp_loss * self.pcp_ratio
92
+ loss_Gmain_all.mean().mul(gain).backward()
93
+
94
+ # # Gpl: Apply path length regularization.
95
+ # if do_Gpl:
96
+ # with torch.autograd.profiler.record_function('Gpl_forward'):
97
+ # batch_size = gen_z.shape[0] // self.pl_batch_shrink
98
+ # gen_img, gen_ws = self.run_G(real_img[:batch_size], mask[:batch_size], gen_z[:batch_size], gen_c[:batch_size], sync=sync)
99
+ # pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
100
+ # with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
101
+ # pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
102
+ # pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
103
+ # pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
104
+ # self.pl_mean.copy_(pl_mean.detach())
105
+ # pl_penalty = (pl_lengths - pl_mean).square()
106
+ # training_stats.report('Loss/pl_penalty', pl_penalty)
107
+ # loss_Gpl = pl_penalty * self.pl_weight
108
+ # training_stats.report('Loss/G/reg', loss_Gpl)
109
+ # with torch.autograd.profiler.record_function('Gpl_backward'):
110
+ # (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
111
+
112
+ # Dmain: Minimize logits for generated images.
113
+ loss_Dgen = 0
114
+ loss_Dgen_stg1 = 0
115
+ if do_Dmain:
116
+ with torch.autograd.profiler.record_function('Dgen_forward'):
117
+ gen_img, _gen_ws, gen_img_stg1 = self.run_G(real_img, mask, gen_z, gen_c, sync=False)
118
+ gen_logits, gen_logits_stg1 = self.run_D(gen_img, mask, gen_img_stg1, gen_c, sync=False) # Gets synced by loss_Dreal.
119
+ training_stats.report('Loss/scores/fake', gen_logits)
120
+ training_stats.report('Loss/signs/fake', gen_logits.sign())
121
+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
122
+ training_stats.report('Loss/scores/fake_s1', gen_logits_stg1)
123
+ training_stats.report('Loss/signs/fake_s1', gen_logits_stg1.sign())
124
+ loss_Dgen_stg1 = torch.nn.functional.softplus(gen_logits_stg1) # -log(1 - sigmoid(gen_logits))
125
+ with torch.autograd.profiler.record_function('Dgen_backward'):
126
+ loss_Dgen_all = loss_Dgen + loss_Dgen_stg1
127
+ loss_Dgen_all.mean().mul(gain).backward()
128
+
129
+ # Dmain: Maximize logits for real images.
130
+ # Dr1: Apply R1 regularization.
131
+ if do_Dmain or do_Dr1:
132
+ name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
133
+ with torch.autograd.profiler.record_function(name + '_forward'):
134
+ real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
135
+ mask_tmp = mask.detach().requires_grad_(do_Dr1)
136
+ real_img_tmp_stg1 = real_img.detach().requires_grad_(do_Dr1)
137
+ real_logits, real_logits_stg1 = self.run_D(real_img_tmp, mask_tmp, real_img_tmp_stg1, real_c, sync=sync)
138
+ training_stats.report('Loss/scores/real', real_logits)
139
+ training_stats.report('Loss/signs/real', real_logits.sign())
140
+ training_stats.report('Loss/scores/real_s1', real_logits_stg1)
141
+ training_stats.report('Loss/signs/real_s1', real_logits_stg1.sign())
142
+
143
+ loss_Dreal = 0
144
+ loss_Dreal_stg1 = 0
145
+ if do_Dmain:
146
+ loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
147
+ loss_Dreal_stg1 = torch.nn.functional.softplus(-real_logits_stg1) # -log(sigmoid(real_logits))
148
+ training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
149
+ training_stats.report('Loss/D/loss_s1', loss_Dgen_stg1 + loss_Dreal_stg1)
150
+
151
+ loss_Dr1 = 0
152
+ loss_Dr1_stg1 = 0
153
+ if do_Dr1:
154
+ with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
155
+ r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
156
+ r1_grads_stg1 = torch.autograd.grad(outputs=[real_logits_stg1.sum()], inputs=[real_img_tmp_stg1], create_graph=True, only_inputs=True)[0]
157
+ r1_penalty = r1_grads.square().sum([1,2,3])
158
+ loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
159
+ training_stats.report('Loss/r1_penalty', r1_penalty)
160
+ training_stats.report('Loss/D/reg', loss_Dr1)
161
+
162
+ r1_penalty_stg1 = r1_grads_stg1.square().sum([1, 2, 3])
163
+ loss_Dr1_stg1 = r1_penalty_stg1 * (self.r1_gamma / 2)
164
+ training_stats.report('Loss/r1_penalty_s1', r1_penalty_stg1)
165
+ training_stats.report('Loss/D/reg_s1', loss_Dr1_stg1)
166
+
167
+ with torch.autograd.profiler.record_function(name + '_backward'):
168
+ ((real_logits + real_logits_stg1) * 0 + loss_Dreal + loss_Dreal_stg1 + loss_Dr1 + loss_Dr1_stg1).mean().mul(gain).backward()
169
+
170
+ #----------------------------------------------------------------------------
losses/pcp.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from losses.vggNet import VGGFeatureExtractor
6
+ import numpy as np
7
+
8
+
9
+ class PerceptualLoss(nn.Module):
10
+ """Perceptual loss with commonly used style loss.
11
+
12
+ Args:
13
+ layer_weights (dict): The weight for each layer of vgg feature.
14
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
15
+ feature layer (before relu5_4) will be extracted with weight
16
+ 1.0 when calculating losses.
17
+ vgg_type (str): The type of vgg network used as feature extractor.
18
+ Default: 'vgg19'.
19
+ use_input_norm (bool): If True, normalize the input image in vgg.
20
+ Default: True.
21
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
22
+ loss will be calculated, and the loss will be multiplied by the
23
+ weight. Default: 1.0.
24
+ style_weight (float): If `style_weight > 0`, the style loss will be
25
+ calculated, and the loss will be multiplied by the weight.
26
+ Default: 0.
27
+ norm_img (bool): If True, the image will be normed to [0, 1]. Note that
28
+ this is different from `use_input_norm`, which normalizes the input in
29
+ the forward function of vgg according to the statistics of the dataset.
30
+ Importantly, the input image must be in range [-1, 1].
31
+ Default: False.
32
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
33
+ """
34
+
35
+ def __init__(self,
36
+ layer_weights,
37
+ vgg_type='vgg19',
38
+ use_input_norm=True,
39
+ use_pcp_loss=True,
40
+ use_style_loss=False,
41
+ norm_img=True,
42
+ criterion='l1'):
43
+ super(PerceptualLoss, self).__init__()
44
+ self.norm_img = norm_img
45
+ self.use_pcp_loss = use_pcp_loss
46
+ self.use_style_loss = use_style_loss
47
+ self.layer_weights = layer_weights
48
+ self.vgg = VGGFeatureExtractor(
49
+ layer_name_list=list(layer_weights.keys()),
50
+ vgg_type=vgg_type,
51
+ use_input_norm=use_input_norm)
52
+
53
+ self.criterion_type = criterion
54
+ if self.criterion_type == 'l1':
55
+ self.criterion = torch.nn.L1Loss()
56
+ elif self.criterion_type == 'l2':
57
+ self.criterion = torch.nn.MSELoss()  # MSELoss is the L2 criterion; torch.nn has no L2loss
58
+ elif self.criterion_type == 'fro':
59
+ self.criterion = None
60
+ else:
61
+ raise NotImplementedError('%s criterion has not been supported.' % self.criterion_type)
62
+
63
+ def forward(self, x, gt):
64
+ """Forward function.
65
+
66
+ Args:
67
+ x (Tensor): Input tensor with shape (n, c, h, w).
68
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
69
+
70
+ Returns:
71
+ Tensor: Forward results.
72
+ """
73
+
74
+ if self.norm_img:
75
+ x = (x + 1.) * 0.5
76
+ gt = (gt + 1.) * 0.5
77
+
78
+ # extract vgg features
79
+ x_features = self.vgg(x)
80
+ gt_features = self.vgg(gt.detach())
81
+
82
+ # calculate perceptual loss
83
+ if self.use_pcp_loss:
84
+ percep_loss = 0
85
+ for k in x_features.keys():
86
+ if self.criterion_type == 'fro':
87
+ percep_loss += torch.norm(
88
+ x_features[k] - gt_features[k],
89
+ p='fro') * self.layer_weights[k]
90
+ else:
91
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
92
+ else:
93
+ percep_loss = None
94
+
95
+ # calculate style loss
96
+ if self.use_style_loss:
97
+ style_loss = 0
98
+ for k in x_features.keys():
99
+ if self.criterion_type == 'fro':
100
+ style_loss += torch.norm(
101
+ self._gram_mat(x_features[k]) -
102
+ self._gram_mat(gt_features[k]),
103
+ p='fro') * self.layer_weights[k]
104
+ else:
105
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) \
106
+ * self.layer_weights[k]
107
+ else:
108
+ style_loss = None
109
+
110
+ return percep_loss, style_loss
111
+
112
+ def _gram_mat(self, x):
113
+ """Calculate Gram matrix.
114
+
115
+ Args:
116
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
117
+
118
+ Returns:
119
+ torch.Tensor: Gram matrix.
120
+ """
121
+ n, c, h, w = x.size()
122
+ features = x.view(n, c, w * h)
123
+ features_t = features.transpose(1, 2)
124
+ gram = features.bmm(features_t) / (c * h * w)
125
+ return gram
126
+
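A minimal usage sketch of PerceptualLoss with the same layer weighting that TwoStageLoss in losses/loss.py passes in (assumes a CUDA device is available; drop .cuda() for CPU). Inputs are expected in [-1, 1] because norm_img=True rescales them to [0, 1] before the VGG forward pass:

import torch
from losses.pcp import PerceptualLoss

pcp = PerceptualLoss(layer_weights=dict(conv4_4=1/4, conv5_4=1/2)).cuda()
fake = torch.rand(2, 3, 256, 256).cuda() * 2 - 1    # stand-in generator output
real = torch.rand(2, 3, 256, 256).cuda() * 2 - 1    # stand-in ground truth
pcp_loss, style_loss = pcp(fake, real)              # style_loss is None unless use_style_loss=True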
losses/vggNet.py ADDED
@@ -0,0 +1,178 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import DataParallel
6
+ from torch.nn.parallel import DistributedDataParallel
7
+ from torchvision.models import vgg as vgg
8
+
9
+
10
+ NAMES = {
11
+ 'vgg11': [
12
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2',
13
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1',
14
+ 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1',
15
+ 'conv5_2', 'relu5_2', 'pool5'
16
+ ],
17
+ 'vgg13': [
18
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
19
+ 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
20
+ 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
21
+ 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
22
+ ],
23
+ 'vgg16': [
24
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
25
+ 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
26
+ 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
27
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
28
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
29
+ 'pool5'
30
+ ],
31
+ 'vgg19': [
32
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
33
+ 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
34
+ 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4',
35
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
36
+ 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
37
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4',
38
+ 'pool5'
39
+ ]
40
+ }
41
+
42
+
43
+ # MODEL_PATH = {
44
+ # 'vgg19': 'losses/pretrained/vgg19-dcbb9e9d.pth'
45
+ # }
46
+
47
+
48
+ def load_model(model, model_path, strict=True, cpu=False):
49
+ if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel):
50
+ model = model.module
51
+ if cpu:
52
+ loaded_model = torch.load(model_path, map_location='cpu')
53
+ else:
54
+ loaded_model = torch.load(model_path)
55
+ model.load_state_dict(loaded_model, strict=strict)
56
+
57
+
58
+ def insert_bn(names):
59
+ """Insert bn layer after each conv.
60
+
61
+ Args:
62
+ names (list): The list of layer names.
63
+
64
+ Returns:
65
+ list: The list of layer names with bn layers.
66
+ """
67
+ names_bn = []
68
+ for name in names:
69
+ names_bn.append(name)
70
+ if 'conv' in name:
71
+ position = name.replace('conv', '')
72
+ names_bn.append('bn' + position)
73
+ return names_bn
74
+
75
+
76
+ class VGGFeatureExtractor(nn.Module):
77
+ """VGG network for feature extraction.
78
+
79
+ In this implementation, users can choose whether to normalize the input
80
+ feature and which type of vgg network to use. Note that the pretrained
81
+ path must match the vgg type.
82
+
83
+ Args:
84
+ layer_name_list (list[str]): Forward function returns the corresponding
85
+ features according to the layer_name_list.
86
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
87
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
88
+ use_input_norm (bool): If True, normalize the input image. Importantly,
89
+ the input feature must be in the range [0, 1]. Default: True.
90
+ requires_grad (bool): If true, the parameters of VGG network will be
91
+ optimized. Default: False.
92
+ remove_pooling (bool): If true, the max pooling operations in VGG net
93
+ will be removed. Default: False.
94
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
95
+ """
96
+
97
+ def __init__(self,
98
+ layer_name_list,
99
+ vgg_type='vgg19',
100
+ use_input_norm=True,
101
+ requires_grad=False,
102
+ remove_pooling=False,
103
+ pooling_stride=2):
104
+ super(VGGFeatureExtractor, self).__init__()
105
+
106
+ self.layer_name_list = layer_name_list
107
+ self.use_input_norm = use_input_norm
108
+
109
+ self.names = NAMES[vgg_type.replace('_bn', '')]
110
+ if 'bn' in vgg_type:
111
+ self.names = insert_bn(self.names)
112
+
113
+ # only borrow layers that will be used to avoid unused params
114
+ max_idx = 0
115
+ for v in layer_name_list:
116
+ idx = self.names.index(v)
117
+ if idx > max_idx:
118
+ max_idx = idx
119
+
120
+ features = getattr(vgg, vgg_type)(pretrained=True).features[:max_idx + 1]
121
+ # vgg_model = getattr(vgg, vgg_type)(pretrained=False)
122
+ # load_model(vgg_model, MODEL_PATH[vgg_type], strict=True)
123
+ # features = vgg_model.features[:max_idx + 1]
124
+
125
+ modified_net = OrderedDict()
126
+ for k, v in zip(self.names, features):
127
+ if 'pool' in k:
128
+ # if remove_pooling is true, pooling operation will be removed
129
+ if remove_pooling:
130
+ continue
131
+ else:
132
+ # in some cases, we may want to change the default stride
133
+ modified_net[k] = nn.MaxPool2d(
134
+ kernel_size=2, stride=pooling_stride)
135
+ else:
136
+ modified_net[k] = v
137
+
138
+ self.vgg_net = nn.Sequential(modified_net)
139
+
140
+ if not requires_grad:
141
+ self.vgg_net.eval()
142
+ for param in self.parameters():
143
+ param.requires_grad = False
144
+ else:
145
+ self.vgg_net.train()
146
+ for param in self.parameters():
147
+ param.requires_grad = True
148
+
149
+ if self.use_input_norm:
150
+ # the mean is for images with range [0, 1]
151
+ self.register_buffer(
152
+ 'mean',
153
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
154
+ # the std is for images with range [0, 1]
155
+ self.register_buffer(
156
+ 'std',
157
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
158
+
159
+ def forward(self, x):
160
+ """Forward function.
161
+
162
+ Args:
163
+ x (Tensor): Input tensor with shape (n, c, h, w).
164
+
165
+ Returns:
166
+ Tensor: Forward results.
167
+ """
168
+
169
+ if self.use_input_norm:
170
+ x = (x - self.mean) / self.std
171
+
172
+ output = {}
173
+ for key, layer in self.vgg_net._modules.items():
174
+ x = layer(x)
175
+ if key in self.layer_name_list:
176
+ output[key] = x.clone()
177
+
178
+ return output
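A short usage sketch for the extractor defined above. It assumes torchvision can download the pretrained vgg19 weights; the requested layer names follow the NAMES table, and the forward pass returns a dict mapping each name to its activation:

import torch
from losses.vggNet import VGGFeatureExtractor

extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'],
                                vgg_type='vgg19', use_input_norm=True)
img = torch.rand(1, 3, 256, 256)   # expected range [0, 1] when use_input_norm=True
with torch.no_grad():
    feats = extractor(img)
for name, f in feats.items():
    print(name, tuple(f.shape))     # e.g. relu1_1 (1, 64, 256, 256)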
metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
+ import numpy as np
15
+ import scipy.linalg
16
+ from . import metric_utils
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def compute_fid(opts, max_real, num_gen):
21
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
+
25
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28
+
29
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32
+
33
+ if opts.rank != 0:
34
+ return float('nan')
35
+
36
+ m = np.square(mu_gen - mu_real).sum()
37
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39
+ return float(fid)
40
+
41
+ #----------------------------------------------------------------------------
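The last few lines above are the closed-form Fréchet distance between two Gaussians fitted to detector features. The same arithmetic, checked on small synthetic statistics with numpy/scipy only (no detector or dataset involved):

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
feat_real = rng.normal(size=(1000, 8))
feat_gen = rng.normal(loc=0.1, size=(1000, 8))

mu_r, sigma_r = feat_real.mean(0), np.cov(feat_real, rowvar=False)
mu_g, sigma_g = feat_gen.mean(0), np.cov(feat_gen, rowvar=False)

m = np.square(mu_g - mu_r).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_g, sigma_r), disp=False)
fid = np.real(m + np.trace(sigma_g + sigma_r - s * 2))
print(fid)   # small positive number for two similar Gaussians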
metrics/inception_discriminative_score.py ADDED
@@ -0,0 +1,37 @@
1
+
2
+ import numpy as np
3
+ import scipy.linalg
4
+ from . import metric_utils
5
+ import sklearn.svm
6
+
7
+ #----------------------------------------------------------------------------
8
+
9
+ def compute_ids(opts, max_real, num_gen):
10
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
11
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
12
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
13
+
14
+ real_activations = metric_utils.compute_feature_stats_for_dataset(
15
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
16
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
17
+
18
+ fake_activations = metric_utils.compute_feature_stats_for_generator(
19
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
20
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
21
+
22
+ if opts.rank != 0:
23
+ return float('nan')
24
+
25
+ svm = sklearn.svm.LinearSVC(dual=False)
26
+ svm_inputs = np.concatenate([real_activations, fake_activations])
27
+ svm_targets = np.array([1] * real_activations.shape[0] + [0] * fake_activations.shape[0])
28
+ print('Fitting ...')
29
+ svm.fit(svm_inputs, svm_targets)
30
+ u_ids = 1 - svm.score(svm_inputs, svm_targets)
31
+ real_outputs = svm.decision_function(real_activations)
32
+ fake_outputs = svm.decision_function(fake_activations)
33
+ p_ids = np.mean(fake_outputs > real_outputs)
34
+
35
+ return float(u_ids), float(p_ids)
36
+
37
+ #----------------------------------------------------------------------------
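The two values returned above are U-IDS, the SVM's error rate on its own training set, and P-IDS, the fraction of fake samples whose decision value exceeds that of a paired real sample. A toy run of the same recipe on random features (sklearn and numpy only):

import numpy as np
import sklearn.svm

rng = np.random.default_rng(0)
real = rng.normal(loc=0.5, size=(500, 16))
fake = rng.normal(loc=0.0, size=(500, 16))

svm = sklearn.svm.LinearSVC(dual=False)
inputs = np.concatenate([real, fake])
targets = np.array([1] * len(real) + [0] * len(fake))
svm.fit(inputs, targets)

u_ids = 1 - svm.score(inputs, targets)   # higher = harder to tell real from fake
p_ids = np.mean(svm.decision_function(fake) > svm.decision_function(real))
print(u_ids, p_ids)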
metrics/inception_score.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_is(opts, num_gen, num_splits):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
+
23
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ capture_all=True, max_items=num_gen).get_all()
26
+
27
+ if opts.rank != 0:
28
+ return float('nan'), float('nan')
29
+
30
+ scores = []
31
+ for i in range(num_splits):
32
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34
+ kl = np.mean(np.sum(kl, axis=1))
35
+ scores.append(np.exp(kl))
36
+ return float(np.mean(scores)), float(np.std(scores))
37
+
38
+ #----------------------------------------------------------------------------
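The score is the exponential of the mean KL divergence between each sample's class distribution and the marginal over its split. The same arithmetic on a toy probability matrix, numpy only:

import numpy as np

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(10), size=1000)   # stand-in for softmax outputs, rows sum to 1

num_splits = 5
scores = []
for i in range(num_splits):
    part = probs[i * len(probs) // num_splits : (i + 1) * len(probs) // num_splits]
    kl = part * (np.log(part) - np.log(part.mean(axis=0, keepdims=True)))
    scores.append(np.exp(np.mean(kl.sum(axis=1))))
print(np.mean(scores), np.std(scores))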
metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
+ GANs". Matches the original implementation by Binkowski et al. at
11
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
+
23
+ real_features = metric_utils.compute_feature_stats_for_dataset(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26
+
27
+ gen_features = metric_utils.compute_feature_stats_for_generator(
28
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30
+
31
+ if opts.rank != 0:
32
+ return float('nan')
33
+
34
+ n = real_features.shape[1]
35
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36
+ t = 0
37
+ for _subset_idx in range(num_subsets):
38
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41
+ b = (x @ y.T / n + 1) ** 3
42
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43
+ kid = t / num_subsets / m
44
+ return float(kid)
45
+
46
+ #----------------------------------------------------------------------------
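The loop above is an unbiased MMD estimator with the polynomial kernel k(x, y) = (x·y/n + 1)^3, averaged over random subsets. A direct re-run of that estimator on toy features, numpy only:

import numpy as np

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 64))
gen = rng.normal(loc=0.05, size=(2000, 64))

n = real.shape[1]
m = 500              # subset size
num_subsets = 20
t = 0.0
for _ in range(num_subsets):
    x = gen[rng.choice(len(gen), m, replace=False)]
    y = real[rng.choice(len(real), m, replace=False)]
    a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
    b = (x @ y.T / n + 1) ** 3
    t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
print(t / num_subsets / m)   # KID estimate; near zero for similar feature sets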
metrics/metric_main.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import json
12
+ import torch
13
+ import dnnlib
14
+
15
+ from . import metric_utils
16
+ from . import frechet_inception_distance
17
+ from . import kernel_inception_distance
18
+ from . import precision_recall
19
+ from . import perceptual_path_length
20
+ from . import inception_score
21
+ from . import psnr_ssim_l1
22
+ from . import inception_discriminative_score
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+ _metric_dict = dict() # name => fn
27
+
28
+ def register_metric(fn):
29
+ assert callable(fn)
30
+ _metric_dict[fn.__name__] = fn
31
+ return fn
32
+
33
+ def is_valid_metric(metric):
34
+ return metric in _metric_dict
35
+
36
+ def list_valid_metrics():
37
+ return list(_metric_dict.keys())
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
42
+ assert is_valid_metric(metric)
43
+ opts = metric_utils.MetricOptions(**kwargs)
44
+
45
+ # Calculate.
46
+ start_time = time.time()
47
+ results = _metric_dict[metric](opts)
48
+ total_time = time.time() - start_time
49
+
50
+ # Broadcast results.
51
+ for key, value in list(results.items()):
52
+ if opts.num_gpus > 1:
53
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
54
+ torch.distributed.broadcast(tensor=value, src=0)
55
+ value = float(value.cpu())
56
+ results[key] = value
57
+
58
+ # Decorate with metadata.
59
+ return dnnlib.EasyDict(
60
+ results = dnnlib.EasyDict(results),
61
+ metric = metric,
62
+ total_time = total_time,
63
+ total_time_str = dnnlib.util.format_time(total_time),
64
+ num_gpus = opts.num_gpus,
65
+ )
66
+
67
+ #----------------------------------------------------------------------------
68
+
69
+ def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
70
+ metric = result_dict['metric']
71
+ assert is_valid_metric(metric)
72
+ if run_dir is not None and snapshot_pkl is not None:
73
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
74
+
75
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
76
+ print(jsonl_line)
77
+ if run_dir is not None and os.path.isdir(run_dir):
78
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
79
+ f.write(jsonl_line + '\n')
80
+
81
+ #----------------------------------------------------------------------------
82
+ # Primary metrics.
83
+
84
+ @register_metric
85
+ def fid2993_full(opts):
86
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
87
+ fid = frechet_inception_distance.compute_fid(opts, max_real=2993, num_gen=2993)
88
+ return dict(fid2993_full=fid)
89
+
90
+ @register_metric
91
+ def fid36k5_full(opts):
92
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
93
+ fid = frechet_inception_distance.compute_fid(opts, max_real=36500, num_gen=36500)
94
+ return dict(fid36k5_full=fid)
95
+
96
+ @register_metric
97
+ def fid_places(opts):
98
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
99
+ fid = frechet_inception_distance.compute_fid(opts, max_real=36500, num_gen=36500)
100
+ return dict(fid36k5_full=fid)
101
+
102
+ @register_metric
103
+ def ids_places(opts):
104
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
105
+ u_ids, p_ids = inception_discriminative_score.compute_ids(opts, max_real=36500, num_gen=36500)
106
+ return dict(u_ids=u_ids, p_ids=p_ids)
107
+
108
+ @register_metric
109
+ def psnr36k5_full(opts):
110
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
111
+ psnr, ssim, l1 = psnr_ssim_l1.compute_psnr(opts, max_real=36500)
112
+ return dict(psnr=psnr, ssim=ssim, l1=l1)
113
+
114
+ @register_metric
115
+ def fid50k_full(opts):
116
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
117
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
118
+ return dict(fid50k_full=fid)
119
+
120
+ @register_metric
121
+ def kid50k_full(opts):
122
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
123
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124
+ return dict(kid50k_full=kid)
125
+
126
+ @register_metric
127
+ def pr50k3_full(opts):
128
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
129
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
131
+
132
+ @register_metric
133
+ def ppl2_wend(opts):
134
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
135
+ return dict(ppl2_wend=ppl)
136
+
137
+ @register_metric
138
+ def is50k(opts):
139
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
140
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
141
+ return dict(is50k_mean=mean, is50k_std=std)
142
+
143
+ #----------------------------------------------------------------------------
144
+ # Legacy metrics.
145
+
146
+ @register_metric
147
+ def fid50k(opts):
148
+ opts.dataset_kwargs.update(max_size=None)
149
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
150
+ return dict(fid50k=fid)
151
+
152
+ @register_metric
153
+ def kid50k(opts):
154
+ opts.dataset_kwargs.update(max_size=None)
155
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
156
+ return dict(kid50k=kid)
157
+
158
+ @register_metric
159
+ def pr50k3(opts):
160
+ opts.dataset_kwargs.update(max_size=None)
161
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
162
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
163
+
164
+ @register_metric
165
+ def ppl_zfull(opts):
166
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
167
+ return dict(ppl_zfull=ppl)
168
+
169
+ @register_metric
170
+ def ppl_wfull(opts):
171
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
172
+ return dict(ppl_wfull=ppl)
173
+
174
+ @register_metric
175
+ def ppl_zend(opts):
176
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
177
+ return dict(ppl_zend=ppl)
178
+
179
+ @register_metric
180
+ def ppl_wend(opts):
181
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
182
+ return dict(ppl_wend=ppl)
183
+
184
+ #----------------------------------------------------------------------------
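New metrics plug into the registry above by decorating a function that takes `opts` and returns a dict keyed by result name. A hedged sketch of a custom entry, assuming it is run from the repository root; `G` and `dataset_kwargs` in the commented-out call are placeholders for a loaded generator and dataset config:

from metrics import metric_main, frechet_inception_distance

@metric_main.register_metric
def fid1k_quick(opts):
    # hypothetical small-budget FID for quick sanity checks (not part of this commit)
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=1000, num_gen=1000)
    return dict(fid1k_quick=fid)

# result = metric_main.calc_metric('fid1k_quick', G=G, dataset_kwargs=dataset_kwargs,
#                                  num_gpus=1, rank=0, device='cuda')
# metric_main.report_metric(result, run_dir='out')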
metrics/metric_utils.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import hashlib
12
+ import pickle
13
+ import copy
14
+ import uuid
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ import math
19
+ import cv2
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ class MetricOptions:
24
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
25
+ assert 0 <= rank < num_gpus
26
+ self.G = G
27
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
28
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
29
+ self.num_gpus = num_gpus
30
+ self.rank = rank
31
+ self.device = device if device is not None else torch.device('cuda', rank)
32
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
33
+ self.cache = cache
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ _feature_detector_cache = dict()
38
+
39
+ def get_feature_detector_name(url):
40
+ return os.path.splitext(url.split('/')[-1])[0]
41
+
42
+ def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
43
+ assert 0 <= rank < num_gpus
44
+ key = (url, device)
45
+ if key not in _feature_detector_cache:
46
+ is_leader = (rank == 0)
47
+ if not is_leader and num_gpus > 1:
48
+ torch.distributed.barrier() # leader goes first
49
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
50
+ _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
51
+ if is_leader and num_gpus > 1:
52
+ torch.distributed.barrier() # others follow
53
+ return _feature_detector_cache[key]
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ class FeatureStats:
58
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
59
+ self.capture_all = capture_all
60
+ self.capture_mean_cov = capture_mean_cov
61
+ self.max_items = max_items
62
+ self.num_items = 0
63
+ self.num_features = None
64
+ self.all_features = None
65
+ self.raw_mean = None
66
+ self.raw_cov = None
67
+
68
+ def set_num_features(self, num_features):
69
+ if self.num_features is not None:
70
+ assert num_features == self.num_features
71
+ else:
72
+ self.num_features = num_features
73
+ self.all_features = []
74
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
75
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
76
+
77
+ def is_full(self):
78
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
79
+
80
+ def append(self, x):
81
+ x = np.asarray(x, dtype=np.float32)
82
+ assert x.ndim == 2
83
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
84
+ if self.num_items >= self.max_items:
85
+ return
86
+ x = x[:self.max_items - self.num_items]
87
+
88
+ self.set_num_features(x.shape[1])
89
+ self.num_items += x.shape[0]
90
+ if self.capture_all:
91
+ self.all_features.append(x)
92
+ if self.capture_mean_cov:
93
+ x64 = x.astype(np.float64)
94
+ self.raw_mean += x64.sum(axis=0)
95
+ self.raw_cov += x64.T @ x64
96
+
97
+ def append_torch(self, x, num_gpus=1, rank=0):
98
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
99
+ assert 0 <= rank < num_gpus
100
+ if num_gpus > 1:
101
+ ys = []
102
+ for src in range(num_gpus):
103
+ y = x.clone()
104
+ torch.distributed.broadcast(y, src=src)
105
+ ys.append(y)
106
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
107
+ self.append(x.cpu().numpy())
108
+
109
+ def get_all(self):
110
+ assert self.capture_all
111
+ return np.concatenate(self.all_features, axis=0)
112
+
113
+ def get_all_torch(self):
114
+ return torch.from_numpy(self.get_all())
115
+
116
+ def get_mean_cov(self):
117
+ assert self.capture_mean_cov
118
+ mean = self.raw_mean / self.num_items
119
+ cov = self.raw_cov / self.num_items
120
+ cov = cov - np.outer(mean, mean)
121
+ return mean, cov
122
+
123
+ def save(self, pkl_file):
124
+ with open(pkl_file, 'wb') as f:
125
+ pickle.dump(self.__dict__, f)
126
+
127
+ @staticmethod
128
+ def load(pkl_file):
129
+ with open(pkl_file, 'rb') as f:
130
+ s = dnnlib.EasyDict(pickle.load(f))
131
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
132
+ obj.__dict__.update(s)
133
+ return obj
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ class ProgressMonitor:
138
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
139
+ self.tag = tag
140
+ self.num_items = num_items
141
+ self.verbose = verbose
142
+ self.flush_interval = flush_interval
143
+ self.progress_fn = progress_fn
144
+ self.pfn_lo = pfn_lo
145
+ self.pfn_hi = pfn_hi
146
+ self.pfn_total = pfn_total
147
+ self.start_time = time.time()
148
+ self.batch_time = self.start_time
149
+ self.batch_items = 0
150
+ if self.progress_fn is not None:
151
+ self.progress_fn(self.pfn_lo, self.pfn_total)
152
+
153
+ def update(self, cur_items):
154
+ assert (self.num_items is None) or (cur_items <= self.num_items)
155
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
156
+ return
157
+ cur_time = time.time()
158
+ total_time = cur_time - self.start_time
159
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
160
+ if (self.verbose) and (self.tag is not None):
161
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
162
+ self.batch_time = cur_time
163
+ self.batch_items = cur_items
164
+
165
+ if (self.progress_fn is not None) and (self.num_items is not None):
166
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
167
+
168
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
169
+ return ProgressMonitor(
170
+ tag = tag,
171
+ num_items = num_items,
172
+ flush_interval = flush_interval,
173
+ verbose = self.verbose,
174
+ progress_fn = self.progress_fn,
175
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
176
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
177
+ pfn_total = self.pfn_total,
178
+ )
179
+
180
+ #----------------------------------------------------------------------------
181
+
182
+ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
183
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
184
+ if data_loader_kwargs is None:
185
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
186
+
187
+ # Try to lookup from cache.
188
+ cache_file = None
189
+ if opts.cache:
190
+ # Choose cache file name.
191
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
192
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
193
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
194
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
195
+
196
+ # Check if the file exists (all processes must agree).
197
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
198
+ if opts.num_gpus > 1:
199
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
200
+ torch.distributed.broadcast(tensor=flag, src=0)
201
+ flag = (float(flag.cpu()) != 0)
202
+
203
+ # Load.
204
+ if flag:
205
+ return FeatureStats.load(cache_file)
206
+
207
+ # Initialize.
208
+ num_items = len(dataset)
209
+ if max_items is not None:
210
+ num_items = min(num_items, max_items)
211
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
212
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
213
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
214
+
215
+ # Main loop.
216
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
217
+ # for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
218
+ # adaptation to inpainting
219
+ for images, masks, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size,
220
+ **data_loader_kwargs):
221
+ # --------------------------------
222
+ if images.shape[1] == 1:
223
+ images = images.repeat([1, 3, 1, 1])
224
+ features = detector(images.to(opts.device), **detector_kwargs)
225
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
226
+ progress.update(stats.num_items)
227
+
228
+ # Save to cache.
229
+ if cache_file is not None and opts.rank == 0:
230
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
231
+ temp_file = cache_file + '.' + uuid.uuid4().hex
232
+ stats.save(temp_file)
233
+ os.replace(temp_file, cache_file) # atomic
234
+ return stats
235
+
236
+ #----------------------------------------------------------------------------
237
+
238
+ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, data_loader_kwargs=None, **stats_kwargs):
239
+ if data_loader_kwargs is None:
240
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
241
+
242
+ if batch_gen is None:
243
+ batch_gen = min(batch_size, 4)
244
+ assert batch_size % batch_gen == 0
245
+
246
+ # Setup generator and load labels.
247
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
248
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
249
+
250
+ # Image generation func.
251
+ def run_generator(img_in, mask_in, z, c):
252
+ img = G(img_in, mask_in, z, c, **opts.G_kwargs)
253
+ # img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
254
+ img = ((img + 1.0) * 127.5).clamp(0, 255).round().to(torch.uint8)
255
+ return img
256
+
257
+ # # JIT.
258
+ # if jit:
259
+ # z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
260
+ # c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
261
+ # run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
262
+
263
+ # Initialize.
264
+ stats = FeatureStats(**stats_kwargs)
265
+ assert stats.max_items is not None
266
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
267
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
268
+
269
+ # Main loop.
270
+ item_subset = [(i * opts.num_gpus + opts.rank) % stats.max_items for i in range((stats.max_items - 1) // opts.num_gpus + 1)]
271
+ for imgs_batch, masks_batch, labels_batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset,
272
+ batch_size=batch_size,
273
+ **data_loader_kwargs):
274
+ images = []
275
+ imgs_gen = (imgs_batch.to(opts.device).to(torch.float32) / 127.5 - 1).split(batch_gen)
276
+ masks_gen = masks_batch.to(opts.device).to(torch.float32).split(batch_gen)
277
+ for img_in, mask_in in zip(imgs_gen, masks_gen):
278
+ z = torch.randn([img_in.shape[0], G.z_dim], device=opts.device)
279
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(img_in.shape[0])]
280
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
281
+ images.append(run_generator(img_in, mask_in, z, c))
282
+ images = torch.cat(images)
283
+ if images.shape[1] == 1:
284
+ images = images.repeat([1, 3, 1, 1])
285
+ features = detector(images, **detector_kwargs)
286
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
287
+ progress.update(stats.num_items)
288
+ return stats
289
+
290
+ #----------------------------------------------------------------------------
291
+
292
+ def compute_image_stats_for_generator(opts, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, data_loader_kwargs=None, **stats_kwargs):
293
+ if data_loader_kwargs is None:
294
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
295
+
296
+ if batch_gen is None:
297
+ batch_gen = min(batch_size, 4)
298
+ assert batch_size % batch_gen == 0
299
+
300
+ # Setup generator and load labels.
301
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
302
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
303
+
304
+ # Image generation func.
305
+ def run_generator(img_in, mask_in, z, c):
306
+ img = G(img_in, mask_in, z, c, **opts.G_kwargs)
307
+ # img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
308
+ img = ((img + 1.0) * 127.5).clamp(0, 255).round().to(torch.uint8)
309
+ return img
310
+
311
+ # Initialize.
312
+ stats = FeatureStats(**stats_kwargs)
313
+ assert stats.max_items is not None
314
+ progress = opts.progress.sub(tag='generator images', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
315
+
316
+ # Main loop.
317
+ item_subset = [(i * opts.num_gpus + opts.rank) % stats.max_items for i in range((stats.max_items - 1) // opts.num_gpus + 1)]
318
+ for imgs_batch, masks_batch, labels_batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset,
319
+ batch_size=batch_size,
320
+ **data_loader_kwargs):
321
+ images = []
322
+ imgs_gen = (imgs_batch.to(opts.device).to(torch.float32) / 127.5 - 1).split(batch_gen)
323
+ masks_gen = masks_batch.to(opts.device).to(torch.float32).split(batch_gen)
324
+ for img_in, mask_in in zip(imgs_gen, masks_gen):
325
+ z = torch.randn([img_in.shape[0], G.z_dim], device=opts.device)
326
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(img_in.shape[0])]
327
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
328
+ images.append(run_generator(img_in, mask_in, z, c))
329
+ images = torch.cat(images)
330
+ if images.shape[1] == 1:
331
+ images = images.repeat([1, 3, 1, 1])
332
+
333
+ assert imgs_batch.shape == images.shape
334
+ metrics = []
335
+ for i in range(imgs_batch.shape[0]):
336
+ img_real = np.transpose(imgs_batch[i].cpu().numpy(), [1, 2, 0])
337
+ img_gen = np.transpose(images[i].cpu().numpy(), [1, 2, 0])
338
+ psnr = calculate_psnr(img_gen, img_real)
339
+ ssim = calculate_ssim(img_gen, img_real)
340
+ l1 = calculate_l1(img_gen, img_real)
341
+ metrics.append([psnr, ssim, l1])
342
+ metrics = torch.from_numpy(np.array(metrics)).to(torch.float32).to(opts.device)
343
+
344
+ stats.append_torch(metrics, num_gpus=opts.num_gpus, rank=opts.rank)
345
+ progress.update(stats.num_items)
346
+ return stats
347
+
348
+
349
+ def calculate_psnr(img1, img2):
350
+ # img1 and img2 have range [0, 255]
351
+ img1 = img1.astype(np.float64)
352
+ img2 = img2.astype(np.float64)
353
+ mse = np.mean((img1 - img2) ** 2)
354
+ if mse == 0:
355
+ return float('inf')
356
+
357
+ return 20 * math.log10(255.0 / math.sqrt(mse))
358
+
359
+
360
+ def calculate_ssim(img1, img2):
361
+ C1 = (0.01 * 255) ** 2
362
+ C2 = (0.03 * 255) ** 2
363
+
364
+ img1 = img1.astype(np.float64)
365
+ img2 = img2.astype(np.float64)
366
+ kernel = cv2.getGaussianKernel(11, 1.5)
367
+ window = np.outer(kernel, kernel.transpose())
368
+
369
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
370
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
371
+ mu1_sq = mu1 ** 2
372
+ mu2_sq = mu2 ** 2
373
+ mu1_mu2 = mu1 * mu2
374
+ sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
375
+ sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
376
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
377
+
378
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
379
+
380
+ return ssim_map.mean()
381
+
382
+
383
+ def calculate_l1(img1, img2):
384
+ img1 = img1.astype(np.float64) / 255.0
385
+ img2 = img2.astype(np.float64) / 255.0
386
+ l1 = np.mean(np.abs(img1 - img2))
387
+
388
+ return l1
389
+
390
+
391
+ # def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
392
+ # if batch_gen is None:
393
+ # batch_gen = min(batch_size, 4)
394
+ # assert batch_size % batch_gen == 0
395
+ #
396
+ # # Setup generator and load labels.
397
+ # G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
398
+ # dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
399
+ #
400
+ # # Image generation func.
401
+ # def run_generator(z, c):
402
+ # img = G(z=z, c=c, **opts.G_kwargs)
403
+ # img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
404
+ # return img
405
+ #
406
+ # # JIT.
407
+ # if jit:
408
+ # z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
409
+ # c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
410
+ # run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
411
+ #
412
+ # # Initialize.
413
+ # stats = FeatureStats(**stats_kwargs)
414
+ # assert stats.max_items is not None
415
+ # progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
416
+ # detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
417
+ #
418
+ # # Main loop.
419
+ # while not stats.is_full():
420
+ # images = []
421
+ # for _i in range(batch_size // batch_gen):
422
+ # z = torch.randn([batch_gen, G.z_dim], device=opts.device)
423
+ # c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
424
+ # c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
425
+ # images.append(run_generator(z, c))
426
+ # images = torch.cat(images)
427
+ # if images.shape[1] == 1:
428
+ # images = images.repeat([1, 3, 1, 1])
429
+ # features = detector(images, **detector_kwargs)
430
+ # stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
431
+ # progress.update(stats.num_items)
432
+ # return stats
433
+ #
434
+ # #----------------------------------------------------------------------------
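FeatureStats above accumulates either raw features or a running mean/covariance and stops at `max_items`. A minimal exercise of the class with random features, no detector, dataset, or GPU required:

import numpy as np
from metrics.metric_utils import FeatureStats

stats = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=1000)
rng = np.random.default_rng(0)
while not stats.is_full():
    stats.append(rng.normal(size=(64, 2048)))   # one "batch" of detector features

mu, sigma = stats.get_mean_cov()
print(stats.num_items, mu.shape, sigma.shape)   # 1000 (2048,) (2048, 2048)
print(stats.get_all().shape)                    # (1000, 2048)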
metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
+ Architecture for Generative Adversarial Networks". Matches the original
11
+ implementation by Karras et al. at
12
+ https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
+
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ from . import metric_utils
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ # Spherical interpolation of a batch of vectors.
23
+ def slerp(a, b, t):
24
+ a = a / a.norm(dim=-1, keepdim=True)
25
+ b = b / b.norm(dim=-1, keepdim=True)
26
+ d = (a * b).sum(dim=-1, keepdim=True)
27
+ p = t * torch.acos(d)
28
+ c = b - d * a
29
+ c = c / c.norm(dim=-1, keepdim=True)
30
+ d = a * torch.cos(p) + c * torch.sin(p)
31
+ d = d / d.norm(dim=-1, keepdim=True)
32
+ return d
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ class PPLSampler(torch.nn.Module):
37
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38
+ assert space in ['z', 'w']
39
+ assert sampling in ['full', 'end']
40
+ super().__init__()
41
+ self.G = copy.deepcopy(G)
42
+ self.G_kwargs = G_kwargs
43
+ self.epsilon = epsilon
44
+ self.space = space
45
+ self.sampling = sampling
46
+ self.crop = crop
47
+ self.vgg16 = copy.deepcopy(vgg16)
48
+
49
+ def forward(self, c):
50
+ # Generate random latents and interpolation t-values.
51
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53
+
54
+ # Interpolate in W or Z.
55
+ if self.space == 'w':
56
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59
+ else: # space == 'z'
60
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
61
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63
+
64
+ # Randomize noise buffers.
65
+ for name, buf in self.G.named_buffers():
66
+ if name.endswith('.noise_const'):
67
+ buf.copy_(torch.randn_like(buf))
68
+
69
+ # Generate images.
70
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71
+
72
+ # Center crop.
73
+ if self.crop:
74
+ assert img.shape[2] == img.shape[3]
75
+ c = img.shape[2] // 8
76
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
77
+
78
+ # Downsample to 256x256.
79
+ factor = self.G.img_resolution // 256
80
+ if factor > 1:
81
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82
+
83
+ # Scale dynamic range from [-1,1] to [0,255].
84
+ img = (img + 1) * (255 / 2)
85
+ if self.G.img_channels == 1:
86
+ img = img.repeat([1, 3, 1, 1])
87
+
88
+ # Evaluate differential LPIPS.
89
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91
+ return dist
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
+ def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97
+ vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99
+
100
+ # Setup sampler.
101
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102
+ sampler.eval().requires_grad_(False).to(opts.device)
103
+ if jit:
104
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105
+ sampler = torch.jit.trace(sampler, [c], check_trace=False)
106
+
107
+ # Sampling loop.
108
+ dist = []
109
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111
+ progress.update(batch_start)
112
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114
+ x = sampler(c)
115
+ for src in range(opts.num_gpus):
116
+ y = x.clone()
117
+ if opts.num_gpus > 1:
118
+ torch.distributed.broadcast(y, src=src)
119
+ dist.append(y)
120
+ progress.update(num_samples)
121
+
122
+ # Compute PPL.
123
+ if opts.rank != 0:
124
+ return float('nan')
125
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
126
+ lo = np.percentile(dist, 1, interpolation='lower')
127
+ hi = np.percentile(dist, 99, interpolation='higher')
128
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129
+ return float(ppl)
130
+
131
+ #----------------------------------------------------------------------------
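The `slerp` helper above interpolates latents on the unit sphere before the epsilon-perturbed pair is rendered and compared with LPIPS. A quick check that its outputs stay unit-norm and that t=0 recovers the (normalized) first endpoint, PyTorch only:

import torch
from metrics.perceptual_path_length import slerp

a = torch.randn(4, 512)
b = torch.randn(4, 512)
t = torch.full([4, 1], 0.5)

mid = slerp(a, b, t)
print(mid.norm(dim=-1))                              # ~1.0 for every row
print(slerp(a, b, torch.zeros(4, 1)).allclose(
    a / a.norm(dim=-1, keepdim=True), atol=1e-5))    # True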
metrics/precision_recall.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Precision/Recall (PR) from the paper "Improved Precision and Recall
10
+ Metric for Assessing Generative Models". Matches the original implementation
11
+ by Kynkaanniemi et al. at
12
+ https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13
+
14
+ import torch
15
+ from . import metric_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20
+ assert 0 <= rank < num_gpus
21
+ num_cols = col_features.shape[0]
22
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24
+ dist_batches = []
25
+ for col_batch in col_batches[rank :: num_gpus]:
26
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27
+ for src in range(num_gpus):
28
+ dist_broadcast = dist_batch.clone()
29
+ if num_gpus > 1:
30
+ torch.distributed.broadcast(dist_broadcast, src=src)
31
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38
+ detector_kwargs = dict(return_features=True)
39
+
40
+ real_features = metric_utils.compute_feature_stats_for_dataset(
41
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43
+
44
+ gen_features = metric_utils.compute_feature_stats_for_generator(
45
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47
+
48
+ results = dict()
49
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50
+ kth = []
51
+ for manifold_batch in manifold.split(row_batch_size):
52
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54
+ kth = torch.cat(kth) if opts.rank == 0 else None
55
+ pred = []
56
+ for probes_batch in probes.split(row_batch_size):
57
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60
+ return results['precision'], results['recall']
61
+
62
+ #----------------------------------------------------------------------------
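Precision asks whether each generated feature lies within the k-NN radius of some real feature; recall swaps the roles. The same rule on toy 2-D features, using plain torch.cdist in place of the multi-GPU helper above:

import torch

torch.manual_seed(0)
real = torch.randn(200, 2)
gen = torch.randn(200, 2) + 0.25
nhood_size = 3

def knn_radius(manifold, k):
    d = torch.cdist(manifold, manifold)
    return d.kthvalue(k + 1, dim=1).values          # k+1 skips the zero self-distance

def coverage(probes, manifold, radius):
    d = torch.cdist(probes, manifold)
    return (d <= radius.unsqueeze(0)).any(dim=1).float().mean()

precision = coverage(gen, real, knn_radius(real, nhood_size))
recall = coverage(real, gen, knn_radius(gen, nhood_size))
print(precision.item(), recall.item())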
metrics/psnr_ssim_l1.py ADDED
@@ -0,0 +1,19 @@
1
+ import numpy as np
2
+ import scipy.linalg
3
+ from . import metric_utils
4
+ import math
5
+ import cv2
6
+
7
+
8
+ def compute_psnr(opts, max_real):
9
+ # stats: numpy, [N, 3]
10
+ stats = metric_utils.compute_image_stats_for_generator(opts=opts, capture_all=True, max_items=max_real).get_all()
11
+
12
+ if opts.rank != 0:
13
+ return float('nan'), float('nan'), float('nan')
14
+
15
+ print('Number of samples: %d' % stats.shape[0])
16
+ avg_psnr = stats[:, 0].sum() / stats.shape[0]
17
+ avg_ssim = stats[:, 1].sum() / stats.shape[0]
18
+ avg_l1 = stats[:, 2].sum() / stats.shape[0]
19
+ return avg_psnr, avg_ssim, avg_l1
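The averages above come from per-image PSNR/SSIM/L1 computed in metric_utils. The underlying arithmetic for PSNR and L1 on two toy images, numpy only (SSIM is omitted because it relies on the OpenCV Gaussian window used there):

import numpy as np

rng = np.random.default_rng(0)
img_real = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float64)
img_gen = np.clip(img_real + rng.normal(scale=5.0, size=img_real.shape), 0, 255)

mse = np.mean((img_gen - img_real) ** 2)
psnr = 20 * np.log10(255.0 / np.sqrt(mse))        # same formula as calculate_psnr
l1 = np.mean(np.abs(img_gen - img_real)) / 255.0  # calculate_l1 works on [0, 1] images
print(psnr, l1)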
models/Places_512_FullData+LAION300k.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0230b8b39287e4a1ec4c53a7c724188cf0fe6dab2610bf79cdff3756b8517291
3
+ size 661315824
models/Places_512_FullData.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d960c4e6b3266b6b9fa74ee4458a9482160d54c06d7738696bc9a9e2b34c66dc
3
+ size 661420475
networks/basic_module.py ADDED
@@ -0,0 +1,583 @@
1
+ import sys
2
+ sys.path.insert(0, '../')
3
+ from collections import OrderedDict
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch_utils import misc
10
+ from torch_utils import persistence
11
+ from torch_utils.ops import conv2d_resample
12
+ from torch_utils.ops import upfirdn2d
13
+ from torch_utils.ops import bias_act
14
+
15
+ #----------------------------------------------------------------------------
16
+
17
+ @misc.profiled_function
18
+ def normalize_2nd_moment(x, dim=1, eps=1e-8):
19
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ @persistence.persistent_class
24
+ class FullyConnectedLayer(nn.Module):
25
+ def __init__(self,
26
+ in_features, # Number of input features.
27
+ out_features, # Number of output features.
28
+ bias = True, # Apply additive bias before the activation function?
29
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
30
+ lr_multiplier = 1, # Learning rate multiplier.
31
+ bias_init = 0, # Initial value for the additive bias.
32
+ ):
33
+ super().__init__()
34
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
35
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
36
+ self.activation = activation
37
+
38
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
39
+ self.bias_gain = lr_multiplier
40
+
41
+ def forward(self, x):
42
+ w = self.weight * self.weight_gain
43
+ b = self.bias
44
+ if b is not None and self.bias_gain != 1:
45
+ b = b * self.bias_gain
46
+
47
+ if self.activation == 'linear' and b is not None:
48
+ # out = torch.addmm(b.unsqueeze(0), x, w.t())
49
+ x = x.matmul(w.t())
50
+ out = x + b.reshape([-1 if i == x.ndim-1 else 1 for i in range(x.ndim)])
51
+ else:
52
+ x = x.matmul(w.t())
53
+ out = bias_act.bias_act(x, b, act=self.activation, dim=x.ndim-1)
54
+ return out
55
+
56
+ #----------------------------------------------------------------------------
57
+
58
+ @persistence.persistent_class
59
+ class Conv2dLayer(nn.Module):
60
+ def __init__(self,
61
+ in_channels, # Number of input channels.
62
+ out_channels, # Number of output channels.
63
+ kernel_size, # Width and height of the convolution kernel.
64
+ bias = True, # Apply additive bias before the activation function?
65
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
66
+ up = 1, # Integer upsampling factor.
67
+ down = 1, # Integer downsampling factor.
68
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
69
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
70
+ trainable = True, # Update the weights of this layer during training?
71
+ ):
72
+ super().__init__()
73
+ self.activation = activation
74
+ self.up = up
75
+ self.down = down
76
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
77
+ self.conv_clamp = conv_clamp
78
+ self.padding = kernel_size // 2
79
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
80
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
81
+
82
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
83
+ bias = torch.zeros([out_channels]) if bias else None
84
+ if trainable:
85
+ self.weight = torch.nn.Parameter(weight)
86
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
87
+ else:
88
+ self.register_buffer('weight', weight)
89
+ if bias is not None:
90
+ self.register_buffer('bias', bias)
91
+ else:
92
+ self.bias = None
93
+
94
+ def forward(self, x, gain=1):
95
+ w = self.weight * self.weight_gain
96
+ x = conv2d_resample.conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
97
+ padding=self.padding)
98
+
99
+ act_gain = self.act_gain * gain
100
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
101
+ out = bias_act.bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
102
+ return out
103
+
104
+ #----------------------------------------------------------------------------
105
+
106
+ @persistence.persistent_class
107
+ class ModulatedConv2d(nn.Module):
108
+ def __init__(self,
109
+ in_channels, # Number of input channels.
110
+ out_channels, # Number of output channels.
111
+ kernel_size, # Width and height of the convolution kernel.
112
+ style_dim, # dimension of the style code
113
+ demodulate=True, # perform demodulation
114
+ up=1, # Integer upsampling factor.
115
+ down=1, # Integer downsampling factor.
116
+ resample_filter=[1,3,3,1], # Low-pass filter to apply when resampling activations.
117
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
118
+ ):
119
+ super().__init__()
120
+ self.demodulate = demodulate
121
+
122
+ self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]))
123
+ self.out_channels = out_channels
124
+ self.kernel_size = kernel_size
125
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
126
+ self.padding = self.kernel_size // 2
127
+ self.up = up
128
+ self.down = down
129
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
130
+ self.conv_clamp = conv_clamp
131
+
132
+ self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
133
+
134
+ def forward(self, x, style):
135
+ batch, in_channels, height, width = x.shape
136
+ style = self.affine(style).view(batch, 1, in_channels, 1, 1)
137
+ weight = self.weight * self.weight_gain * style
138
+
139
+ if self.demodulate:
140
+ decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
141
+ weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
142
+
143
+ weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size)
144
+ x = x.view(1, batch * in_channels, height, width)
145
+ x = conv2d_resample.conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
146
+ padding=self.padding, groups=batch)
147
+ out = x.view(batch, self.out_channels, *x.shape[2:])
148
+
149
+ return out
150
+
151
+ #----------------------------------------------------------------------------
152
+
153
+ @persistence.persistent_class
154
+ class StyleConv(torch.nn.Module):
155
+ def __init__(self,
156
+ in_channels, # Number of input channels.
157
+ out_channels, # Number of output channels.
158
+ style_dim, # Intermediate latent (W) dimensionality.
159
+ resolution, # Resolution of this layer.
160
+ kernel_size = 3, # Convolution kernel size.
161
+ up = 1, # Integer upsampling factor.
162
+ use_noise = True, # Enable noise input?
163
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
164
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
165
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
166
+ demodulate = True, # perform demodulation
167
+ ):
168
+ super().__init__()
169
+
170
+ self.conv = ModulatedConv2d(in_channels=in_channels,
171
+ out_channels=out_channels,
172
+ kernel_size=kernel_size,
173
+ style_dim=style_dim,
174
+ demodulate=demodulate,
175
+ up=up,
176
+ resample_filter=resample_filter,
177
+ conv_clamp=conv_clamp)
178
+
179
+ self.use_noise = use_noise
180
+ self.resolution = resolution
181
+ if use_noise:
182
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
183
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
184
+
185
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
186
+ self.activation = activation
187
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
188
+ self.conv_clamp = conv_clamp
189
+
190
+ def forward(self, x, style, noise_mode='random', gain=1):
191
+ x = self.conv(x, style)
192
+
193
+ assert noise_mode in ['random', 'const', 'none']
194
+
195
+ if self.use_noise:
196
+ if noise_mode == 'random':
197
+ xh, xw = x.size()[-2:]
198
+ noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
199
+ * self.noise_strength
200
+ if noise_mode == 'const':
201
+ noise = self.noise_const * self.noise_strength
+ if noise_mode == 'none':
+ noise = 0  # avoid referencing an undefined 'noise' below when noise injection is disabled at call time
202
+ x = x + noise
203
+
204
+ act_gain = self.act_gain * gain
205
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
206
+ out = bias_act.bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
207
+
208
+ return out
209
+
210
+ #----------------------------------------------------------------------------
211
+
212
+ @persistence.persistent_class
213
+ class ToRGB(torch.nn.Module):
214
+ def __init__(self,
215
+ in_channels,
216
+ out_channels,
217
+ style_dim,
218
+ kernel_size=1,
219
+ resample_filter=[1,3,3,1],
220
+ conv_clamp=None,
221
+ demodulate=False):
222
+ super().__init__()
223
+
224
+ self.conv = ModulatedConv2d(in_channels=in_channels,
225
+ out_channels=out_channels,
226
+ kernel_size=kernel_size,
227
+ style_dim=style_dim,
228
+ demodulate=demodulate,
229
+ resample_filter=resample_filter,
230
+ conv_clamp=conv_clamp)
231
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
232
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
233
+ self.conv_clamp = conv_clamp
234
+
235
+ def forward(self, x, style, skip=None):
236
+ x = self.conv(x, style)
237
+ out = bias_act.bias_act(x, self.bias, clamp=self.conv_clamp)
238
+
239
+ if skip is not None:
240
+ if skip.shape != out.shape:
241
+ skip = upfirdn2d.upsample2d(skip, self.resample_filter)
242
+ out = out + skip
243
+
244
+ return out
245
+
246
+ #----------------------------------------------------------------------------
247
+
248
+ @misc.profiled_function
249
+ def get_style_code(a, b):
250
+ return torch.cat([a, b], dim=1)
251
+
252
+ #----------------------------------------------------------------------------
253
+
254
+ @persistence.persistent_class
255
+ class DecBlockFirst(nn.Module):
256
+ def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
257
+ super().__init__()
258
+ self.fc = FullyConnectedLayer(in_features=in_channels*2,
259
+ out_features=in_channels*4**2,
260
+ activation=activation)
261
+ self.conv = StyleConv(in_channels=in_channels,
262
+ out_channels=out_channels,
263
+ style_dim=style_dim,
264
+ resolution=4,
265
+ kernel_size=3,
266
+ use_noise=use_noise,
267
+ activation=activation,
268
+ demodulate=demodulate,
269
+ )
270
+ self.toRGB = ToRGB(in_channels=out_channels,
271
+ out_channels=img_channels,
272
+ style_dim=style_dim,
273
+ kernel_size=1,
274
+ demodulate=False,
275
+ )
276
+
277
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
278
+ x = self.fc(x).view(x.shape[0], -1, 4, 4)
279
+ x = x + E_features[2]
280
+ style = get_style_code(ws[:, 0], gs)
281
+ x = self.conv(x, style, noise_mode=noise_mode)
282
+ style = get_style_code(ws[:, 1], gs)
283
+ img = self.toRGB(x, style, skip=None)
284
+
285
+ return x, img
286
+
287
+
288
+ @persistence.persistent_class
289
+ class DecBlockFirstV2(nn.Module):
290
+ def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
291
+ super().__init__()
292
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
293
+ out_channels=in_channels,
294
+ kernel_size=3,
295
+ activation=activation,
296
+ )
297
+ self.conv1 = StyleConv(in_channels=in_channels,
298
+ out_channels=out_channels,
299
+ style_dim=style_dim,
300
+ resolution=4,
301
+ kernel_size=3,
302
+ use_noise=use_noise,
303
+ activation=activation,
304
+ demodulate=demodulate,
305
+ )
306
+ self.toRGB = ToRGB(in_channels=out_channels,
307
+ out_channels=img_channels,
308
+ style_dim=style_dim,
309
+ kernel_size=1,
310
+ demodulate=False,
311
+ )
312
+
313
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
314
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
315
+ x = self.conv0(x)
316
+ x = x + E_features[2]
317
+ style = get_style_code(ws[:, 0], gs)
318
+ x = self.conv1(x, style, noise_mode=noise_mode)
319
+ style = get_style_code(ws[:, 1], gs)
320
+ img = self.toRGB(x, style, skip=None)
321
+
322
+ return x, img
323
+
324
+ #----------------------------------------------------------------------------
325
+
326
+ @persistence.persistent_class
327
+ class DecBlock(nn.Module):
328
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): # res = 2, ..., resolution_log2
329
+ super().__init__()
330
+ self.res = res
331
+
332
+ self.conv0 = StyleConv(in_channels=in_channels,
333
+ out_channels=out_channels,
334
+ style_dim=style_dim,
335
+ resolution=2**res,
336
+ kernel_size=3,
337
+ up=2,
338
+ use_noise=use_noise,
339
+ activation=activation,
340
+ demodulate=demodulate,
341
+ )
342
+ self.conv1 = StyleConv(in_channels=out_channels,
343
+ out_channels=out_channels,
344
+ style_dim=style_dim,
345
+ resolution=2**res,
346
+ kernel_size=3,
347
+ use_noise=use_noise,
348
+ activation=activation,
349
+ demodulate=demodulate,
350
+ )
351
+ self.toRGB = ToRGB(in_channels=out_channels,
352
+ out_channels=img_channels,
353
+ style_dim=style_dim,
354
+ kernel_size=1,
355
+ demodulate=False,
356
+ )
357
+
358
+ def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
359
+ style = get_style_code(ws[:, self.res * 2 - 5], gs)
360
+ x = self.conv0(x, style, noise_mode=noise_mode)
361
+ x = x + E_features[self.res]
362
+ style = get_style_code(ws[:, self.res * 2 - 4], gs)
363
+ x = self.conv1(x, style, noise_mode=noise_mode)
364
+ style = get_style_code(ws[:, self.res * 2 - 3], gs)
365
+ img = self.toRGB(x, style, skip=img)
366
+
367
+ return x, img
368
+
369
+ #----------------------------------------------------------------------------
370
+
371
+ @persistence.persistent_class
372
+ class MappingNet(torch.nn.Module):
373
+ def __init__(self,
374
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
375
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
376
+ w_dim, # Intermediate latent (W) dimensionality.
377
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
378
+ num_layers = 8, # Number of mapping layers.
379
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
380
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
381
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
382
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
383
+ w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
384
+ ):
385
+ super().__init__()
386
+ self.z_dim = z_dim
387
+ self.c_dim = c_dim
388
+ self.w_dim = w_dim
389
+ self.num_ws = num_ws
390
+ self.num_layers = num_layers
391
+ self.w_avg_beta = w_avg_beta
392
+
393
+ if embed_features is None:
394
+ embed_features = w_dim
395
+ if c_dim == 0:
396
+ embed_features = 0
397
+ if layer_features is None:
398
+ layer_features = w_dim
399
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
400
+
401
+ if c_dim > 0:
402
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
403
+ for idx in range(num_layers):
404
+ in_features = features_list[idx]
405
+ out_features = features_list[idx + 1]
406
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
407
+ setattr(self, f'fc{idx}', layer)
408
+
409
+ if num_ws is not None and w_avg_beta is not None:
410
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
411
+
412
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
413
+ # Embed, normalize, and concat inputs.
414
+ x = None
415
+ with torch.autograd.profiler.record_function('input'):
416
+ if self.z_dim > 0:
417
+ x = normalize_2nd_moment(z.to(torch.float32))
418
+ if self.c_dim > 0:
419
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
420
+ x = torch.cat([x, y], dim=1) if x is not None else y
421
+
422
+ # Main layers.
423
+ for idx in range(self.num_layers):
424
+ layer = getattr(self, f'fc{idx}')
425
+ x = layer(x)
426
+
427
+ # Update moving average of W.
428
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
429
+ with torch.autograd.profiler.record_function('update_w_avg'):
430
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
431
+
432
+ # Broadcast.
433
+ if self.num_ws is not None:
434
+ with torch.autograd.profiler.record_function('broadcast'):
435
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
436
+
437
+ # Apply truncation.
438
+ if truncation_psi != 1:
439
+ with torch.autograd.profiler.record_function('truncate'):
440
+ assert self.w_avg_beta is not None
441
+ if self.num_ws is None or truncation_cutoff is None:
442
+ x = self.w_avg.lerp(x, truncation_psi)
443
+ else:
444
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
445
+
446
+ return x
447
+
448
+ #----------------------------------------------------------------------------
449
+
450
+ @persistence.persistent_class
451
+ class DisFromRGB(nn.Module):
452
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
453
+ super().__init__()
454
+ self.conv = Conv2dLayer(in_channels=in_channels,
455
+ out_channels=out_channels,
456
+ kernel_size=1,
457
+ activation=activation,
458
+ )
459
+
460
+ def forward(self, x):
461
+ return self.conv(x)
462
+
463
+ #----------------------------------------------------------------------------
464
+
465
+ @persistence.persistent_class
466
+ class DisBlock(nn.Module):
467
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
468
+ super().__init__()
469
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
470
+ out_channels=in_channels,
471
+ kernel_size=3,
472
+ activation=activation,
473
+ )
474
+ self.conv1 = Conv2dLayer(in_channels=in_channels,
475
+ out_channels=out_channels,
476
+ kernel_size=3,
477
+ down=2,
478
+ activation=activation,
479
+ )
480
+ self.skip = Conv2dLayer(in_channels=in_channels,
481
+ out_channels=out_channels,
482
+ kernel_size=1,
483
+ down=2,
484
+ bias=False,
485
+ )
486
+
487
+ def forward(self, x):
488
+ skip = self.skip(x, gain=np.sqrt(0.5))
489
+ x = self.conv0(x)
490
+ x = self.conv1(x, gain=np.sqrt(0.5))
491
+ out = skip + x
492
+
493
+ return out
494
+
495
+ #----------------------------------------------------------------------------
496
+
497
+ @persistence.persistent_class
498
+ class MinibatchStdLayer(torch.nn.Module):
499
+ def __init__(self, group_size, num_channels=1):
500
+ super().__init__()
501
+ self.group_size = group_size
502
+ self.num_channels = num_channels
503
+
504
+ def forward(self, x):
505
+ N, C, H, W = x.shape
506
+ with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
507
+ G = torch.min(torch.as_tensor(self.group_size),
508
+ torch.as_tensor(N)) if self.group_size is not None else N
509
+ F = self.num_channels
510
+ c = C // F
511
+
512
+ y = x.reshape(G, -1, F, c, H,
513
+ W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
514
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
515
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
516
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
517
+ y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
518
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
519
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
520
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
521
+ return x
522
+
523
+ #----------------------------------------------------------------------------
524
+
525
+ @persistence.persistent_class
526
+ class Discriminator(torch.nn.Module):
527
+ def __init__(self,
528
+ c_dim, # Conditioning label (C) dimensionality.
529
+ img_resolution, # Input resolution.
530
+ img_channels, # Number of input color channels.
531
+ channel_base = 32768, # Overall multiplier for the number of channels.
532
+ channel_max = 512, # Maximum number of channels in any layer.
533
+ channel_decay = 1,
534
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
535
+ activation = 'lrelu',
536
+ mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
537
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
538
+ ):
539
+ super().__init__()
540
+ self.c_dim = c_dim
541
+ self.img_resolution = img_resolution
542
+ self.img_channels = img_channels
543
+
544
+ resolution_log2 = int(np.log2(img_resolution))
545
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
546
+ self.resolution_log2 = resolution_log2
547
+
548
+ def nf(stage):
549
+ return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)
550
+
551
+ if cmap_dim == None:
552
+ cmap_dim = nf(2)
553
+ if c_dim == 0:
554
+ cmap_dim = 0
555
+ self.cmap_dim = cmap_dim
556
+
557
+ if c_dim > 0:
558
+ self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
559
+
560
+ Dis = [DisFromRGB(img_channels+1, nf(resolution_log2), activation)]
561
+ for res in range(resolution_log2, 2, -1):
562
+ Dis.append(DisBlock(nf(res), nf(res-1), activation))
563
+
564
+ if mbstd_num_channels > 0:
565
+ Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
566
+ Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
567
+ self.Dis = nn.Sequential(*Dis)
568
+
569
+ self.fc0 = FullyConnectedLayer(nf(2)*4**2, nf(2), activation=activation)
570
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
571
+
572
+ def forward(self, images_in, masks_in, c):
573
+ x = torch.cat([masks_in - 0.5, images_in], dim=1)
574
+ x = self.Dis(x)
575
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
576
+
577
+ if self.c_dim > 0:
578
+ cmap = self.mapping(None, c)
579
+
580
+ if self.cmap_dim > 0:
581
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
582
+
583
+ return x
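A minimal usage sketch of the blocks above, assuming the repository root is on sys.path so that networks.basic_module and the torch_utils ops resolve (shapes and channel counts here are illustrative only):

import torch
from networks.basic_module import StyleConv, ToRGB

conv = StyleConv(in_channels=64, out_channels=64, style_dim=512,
                 resolution=16, kernel_size=3, use_noise=True, activation='lrelu')
to_rgb = ToRGB(in_channels=64, out_channels=3, style_dim=512)

x = torch.randn(2, 64, 16, 16)      # feature map
w = torch.randn(2, 512)             # style code, e.g. one slice of ws from MappingNet
x = conv(x, w, noise_mode='const')  # modulated conv + constant noise + lrelu
img = to_rgb(x, w, skip=None)       # 3-channel output at the same resolution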
networks/mat.py ADDED
@@ -0,0 +1,996 @@
1
+ import numpy as np
2
+ import math
3
+ import sys
4
+ sys.path.insert(0, '../')
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as checkpoint
10
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11
+
12
+ from torch_utils import misc
13
+ from torch_utils import persistence
14
+ from networks.basic_module import FullyConnectedLayer, Conv2dLayer, MappingNet, MinibatchStdLayer, DisFromRGB, DisBlock, StyleConv, ToRGB, get_style_code
15
+
16
+
17
+ @misc.profiled_function
18
+ def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
19
+ NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
20
+ return NF[2 ** stage]
21
+
22
+
23
+ @persistence.persistent_class
24
+ class Mlp(nn.Module):
25
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu')
30
+ self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features)
31
+
32
+ def forward(self, x):
33
+ x = self.fc1(x)
34
+ x = self.fc2(x)
35
+ return x
36
+
37
+
38
+ @misc.profiled_function
39
+ def window_partition(x, window_size):
40
+ """
41
+ Args:
42
+ x: (B, H, W, C)
43
+ window_size (int): window size
44
+ Returns:
45
+ windows: (num_windows*B, window_size, window_size, C)
46
+ """
47
+ B, H, W, C = x.shape
48
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
49
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
50
+ return windows
51
+
52
+
53
+ @misc.profiled_function
54
+ def window_reverse(windows, window_size, H, W):
55
+ """
56
+ Args:
57
+ windows: (num_windows*B, window_size, window_size, C)
58
+ window_size (int): Window size
59
+ H (int): Height of image
60
+ W (int): Width of image
61
+ Returns:
62
+ x: (B, H, W, C)
63
+ """
64
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
65
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
66
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
67
+ return x
68
+
69
+
70
+ @persistence.persistent_class
71
+ class Conv2dLayerPartial(nn.Module):
72
+ def __init__(self,
73
+ in_channels, # Number of input channels.
74
+ out_channels, # Number of output channels.
75
+ kernel_size, # Width and height of the convolution kernel.
76
+ bias = True, # Apply additive bias before the activation function?
77
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
78
+ up = 1, # Integer upsampling factor.
79
+ down = 1, # Integer downsampling factor.
80
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
81
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
82
+ trainable = True, # Update the weights of this layer during training?
83
+ ):
84
+ super().__init__()
85
+ self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter,
86
+ conv_clamp, trainable)
87
+
88
+ self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
89
+ self.slide_winsize = kernel_size ** 2
90
+ self.stride = down
91
+ self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
92
+
93
+ def forward(self, x, mask=None):
94
+ if mask is not None:
95
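+ # partial-convolution-style masking: rescale outputs by the per-window valid-pixel coverage and propagate an updated mask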
+ with torch.no_grad():
96
+ if self.weight_maskUpdater.type() != x.type():
97
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
98
+ update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding)
99
+ mask_ratio = self.slide_winsize / (update_mask + 1e-8)
100
+ update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
101
+ mask_ratio = torch.mul(mask_ratio, update_mask)
102
+ x = self.conv(x)
103
+ x = torch.mul(x, mask_ratio)
104
+ return x, update_mask
105
+ else:
106
+ x = self.conv(x)
107
+ return x, None
108
+
109
+
110
+ @persistence.persistent_class
111
+ class WindowAttention(nn.Module):
112
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
113
+ It supports both shifted and non-shifted windows.
114
+ Args:
115
+ dim (int): Number of input channels.
116
+ window_size (tuple[int]): The height and width of the window.
117
+ num_heads (int): Number of attention heads.
118
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
119
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
120
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
121
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
122
+ """
123
+
124
+ def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
125
+
126
+ super().__init__()
127
+ self.dim = dim
128
+ self.window_size = window_size # Wh, Ww
129
+ self.num_heads = num_heads
130
+ head_dim = dim // num_heads
131
+ self.scale = qk_scale or head_dim ** -0.5
132
+
133
+ self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
134
+ self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
135
+ self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
136
+ self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
137
+
138
+ self.softmax = nn.Softmax(dim=-1)
139
+
140
+ def forward(self, x, mask_windows=None, mask=None):
141
+ """
142
+ Args:
143
+ x: input features with shape of (num_windows*B, N, C)
144
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
145
+ """
146
+ B_, N, C = x.shape
147
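+ # attention is computed on L2-normalized tokens; the values keep the unnormalized features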
+ norm_x = F.normalize(x, p=2.0, dim=-1)
148
+ q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
149
+ k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1)
150
+ v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
151
+
152
+ attn = (q @ k) * self.scale
153
+
154
+ if mask is not None:
155
+ nW = mask.shape[0]
156
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
157
+ attn = attn.view(-1, self.num_heads, N, N)
158
+
159
+ if mask_windows is not None:
160
+ attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
161
+ attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill(
162
+ attn_mask_windows == 1, float(0.0))
163
+ with torch.no_grad():
164
+ mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1)
165
+
166
+ attn = self.softmax(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
169
+ x = self.proj(x)
170
+ return x, mask_windows
171
+
172
+
173
+ @persistence.persistent_class
174
+ class SwinTransformerBlock(nn.Module):
175
+ r""" Swin Transformer Block.
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ input_resolution (tuple[int]): Input resolution.
179
+ num_heads (int): Number of attention heads.
180
+ window_size (int): Window size.
181
+ shift_size (int): Shift size for SW-MSA.
182
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
183
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
184
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
185
+ drop (float, optional): Dropout rate. Default: 0.0
186
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
187
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
188
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
189
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
190
+ """
191
+
192
+ def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0,
193
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
194
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
195
+ super().__init__()
196
+ self.dim = dim
197
+ self.input_resolution = input_resolution
198
+ self.num_heads = num_heads
199
+ self.window_size = window_size
200
+ self.shift_size = shift_size
201
+ self.mlp_ratio = mlp_ratio
202
+ if min(self.input_resolution) <= self.window_size:
203
+ # if window size is larger than input resolution, we don't partition windows
204
+ self.shift_size = 0
205
+ self.window_size = min(self.input_resolution)
206
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
207
+
208
+ if self.shift_size > 0:
209
+ down_ratio = 1
210
+ self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
211
+ down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
212
+ proj_drop=drop)
213
+
214
+ self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu')
215
+
216
+ mlp_hidden_dim = int(dim * mlp_ratio)
217
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
218
+
219
+ if self.shift_size > 0:
220
+ attn_mask = self.calculate_mask(self.input_resolution)
221
+ else:
222
+ attn_mask = None
223
+
224
+ self.register_buffer("attn_mask", attn_mask)
225
+
226
+ def calculate_mask(self, x_size):
227
+ # calculate attention mask for SW-MSA
228
+ H, W = x_size
229
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
230
+ h_slices = (slice(0, -self.window_size),
231
+ slice(-self.window_size, -self.shift_size),
232
+ slice(-self.shift_size, None))
233
+ w_slices = (slice(0, -self.window_size),
234
+ slice(-self.window_size, -self.shift_size),
235
+ slice(-self.shift_size, None))
236
+ cnt = 0
237
+ for h in h_slices:
238
+ for w in w_slices:
239
+ img_mask[:, h, w, :] = cnt
240
+ cnt += 1
241
+
242
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
243
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
244
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
245
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
246
+
247
+ return attn_mask
248
+
249
+ def forward(self, x, x_size, mask=None):
250
+ # H, W = self.input_resolution
251
+ H, W = x_size
252
+ B, L, C = x.shape
253
+ assert L == H * W, "input feature has wrong size"
254
+
255
+ shortcut = x
256
+ x = x.view(B, H, W, C)
257
+ if mask is not None:
258
+ mask = mask.view(B, H, W, 1)
259
+
260
+ # cyclic shift
261
+ if self.shift_size > 0:
262
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
263
+ if mask is not None:
264
+ shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
265
+ else:
266
+ shifted_x = x
267
+ if mask is not None:
268
+ shifted_mask = mask
269
+
270
+ # partition windows
271
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
272
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
273
+ if mask is not None:
274
+ mask_windows = window_partition(shifted_mask, self.window_size)
275
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
276
+ else:
277
+ mask_windows = None
278
+
279
+ # W-MSA/SW-MSA (recompute the attention mask when the runtime size differs from input_resolution, so testing works on any image size that is a multiple of the window size)
280
+ if self.input_resolution == x_size:
281
+ attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
282
+ else:
283
+ attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to(x.device)) # nW*B, window_size*window_size, C
284
+
285
+ # merge windows
286
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
287
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
288
+ if mask is not None:
289
+ mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
290
+ shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
291
+
292
+ # reverse cyclic shift
293
+ if self.shift_size > 0:
294
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
295
+ if mask is not None:
296
+ mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
297
+ else:
298
+ x = shifted_x
299
+ if mask is not None:
300
+ mask = shifted_mask
301
+ x = x.view(B, H * W, C)
302
+ if mask is not None:
303
+ mask = mask.view(B, H * W, 1)
304
+
305
+ # FFN
306
+ x = self.fuse(torch.cat([shortcut, x], dim=-1))
307
+ x = self.mlp(x)
308
+
309
+ return x, mask
310
+
311
+
312
+ @persistence.persistent_class
313
+ class PatchMerging(nn.Module):
314
+ def __init__(self, in_channels, out_channels, down=2):
315
+ super().__init__()
316
+ self.conv = Conv2dLayerPartial(in_channels=in_channels,
317
+ out_channels=out_channels,
318
+ kernel_size=3,
319
+ activation='lrelu',
320
+ down=down,
321
+ )
322
+ self.down = down
323
+
324
+ def forward(self, x, x_size, mask=None):
325
+ x = token2feature(x, x_size)
326
+ if mask is not None:
327
+ mask = token2feature(mask, x_size)
328
+ x, mask = self.conv(x, mask)
329
+ if self.down != 1:
330
+ ratio = 1 / self.down
331
+ x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
332
+ x = feature2token(x)
333
+ if mask is not None:
334
+ mask = feature2token(mask)
335
+ return x, x_size, mask
336
+
337
+
338
+ @persistence.persistent_class
339
+ class PatchUpsampling(nn.Module):
340
+ def __init__(self, in_channels, out_channels, up=2):
341
+ super().__init__()
342
+ self.conv = Conv2dLayerPartial(in_channels=in_channels,
343
+ out_channels=out_channels,
344
+ kernel_size=3,
345
+ activation='lrelu',
346
+ up=up,
347
+ )
348
+ self.up = up
349
+
350
+ def forward(self, x, x_size, mask=None):
351
+ x = token2feature(x, x_size)
352
+ if mask is not None:
353
+ mask = token2feature(mask, x_size)
354
+ x, mask = self.conv(x, mask)
355
+ if self.up != 1:
356
+ x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
357
+ x = feature2token(x)
358
+ if mask is not None:
359
+ mask = feature2token(mask)
360
+ return x, x_size, mask
361
+
362
+
363
+
364
+ @persistence.persistent_class
365
+ class BasicLayer(nn.Module):
366
+ """ A basic Swin Transformer layer for one stage.
367
+ Args:
368
+ dim (int): Number of input channels.
369
+ input_resolution (tuple[int]): Input resolution.
370
+ depth (int): Number of blocks.
371
+ num_heads (int): Number of attention heads.
372
+ window_size (int): Local window size.
373
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
374
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
375
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
376
+ drop (float, optional): Dropout rate. Default: 0.0
377
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
378
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
379
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
380
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
381
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
382
+ """
383
+
384
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1,
385
+ mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
386
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
387
+
388
+ super().__init__()
389
+ self.dim = dim
390
+ self.input_resolution = input_resolution
391
+ self.depth = depth
392
+ self.use_checkpoint = use_checkpoint
393
+
394
+ # patch merging layer
395
+ if downsample is not None:
396
+ # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
397
+ self.downsample = downsample
398
+ else:
399
+ self.downsample = None
400
+
401
+ # build blocks
402
+ self.blocks = nn.ModuleList([
403
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
404
+ num_heads=num_heads, down_ratio=down_ratio, window_size=window_size,
405
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
406
+ mlp_ratio=mlp_ratio,
407
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
408
+ drop=drop, attn_drop=attn_drop,
409
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
410
+ norm_layer=norm_layer)
411
+ for i in range(depth)])
412
+
413
+ self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu')
414
+
415
+ def forward(self, x, x_size, mask=None):
416
+ if self.downsample is not None:
417
+ x, x_size, mask = self.downsample(x, x_size, mask)
418
+ identity = x
419
+ for blk in self.blocks:
420
+ if self.use_checkpoint:
421
+ x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
422
+ else:
423
+ x, mask = blk(x, x_size, mask)
424
+ if mask is not None:
425
+ mask = token2feature(mask, x_size)
426
+ x, mask = self.conv(token2feature(x, x_size), mask)
427
+ x = feature2token(x) + identity
428
+ if mask is not None:
429
+ mask = feature2token(mask)
430
+ return x, x_size, mask
431
+
432
+
433
+ @persistence.persistent_class
434
+ class ToToken(nn.Module):
435
+ def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
436
+ super().__init__()
437
+
438
+ self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size, activation='lrelu')
439
+
440
+ def forward(self, x, mask):
441
+ x, mask = self.proj(x, mask)
442
+
443
+ return x, mask
444
+
445
+ #----------------------------------------------------------------------------
446
+
447
+ @persistence.persistent_class
448
+ class EncFromRGB(nn.Module):
449
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
450
+ super().__init__()
451
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
452
+ out_channels=out_channels,
453
+ kernel_size=1,
454
+ activation=activation,
455
+ )
456
+ self.conv1 = Conv2dLayer(in_channels=out_channels,
457
+ out_channels=out_channels,
458
+ kernel_size=3,
459
+ activation=activation,
460
+ )
461
+
462
+ def forward(self, x):
463
+ x = self.conv0(x)
464
+ x = self.conv1(x)
465
+
466
+ return x
467
+
468
+ @persistence.persistent_class
469
+ class ConvBlockDown(nn.Module):
470
+ def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log
471
+ super().__init__()
472
+
473
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
474
+ out_channels=out_channels,
475
+ kernel_size=3,
476
+ activation=activation,
477
+ down=2,
478
+ )
479
+ self.conv1 = Conv2dLayer(in_channels=out_channels,
480
+ out_channels=out_channels,
481
+ kernel_size=3,
482
+ activation=activation,
483
+ )
484
+
485
+ def forward(self, x):
486
+ x = self.conv0(x)
487
+ x = self.conv1(x)
488
+
489
+ return x
490
+
491
+
492
+ def token2feature(x, x_size):
493
+ B, N, C = x.shape
494
+ h, w = x_size
495
+ x = x.permute(0, 2, 1).reshape(B, C, h, w)
496
+ return x
497
+
498
+
499
+ def feature2token(x):
500
+ B, C, H, W = x.shape
501
+ x = x.view(B, C, -1).transpose(1, 2)
502
+ return x
503
+
504
+
505
+ @persistence.persistent_class
506
+ class Encoder(nn.Module):
507
+ def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1):
508
+ super().__init__()
509
+
510
+ self.resolution = []
511
+
512
+ for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
513
+ res = 2 ** i
514
+ self.resolution.append(res)
515
+ if i == res_log2:
516
+ block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
517
+ else:
518
+ block = ConvBlockDown(nf(i+1), nf(i), activation)
519
+ setattr(self, 'EncConv_Block_%dx%d' % (res, res), block)
520
+
521
+ def forward(self, x):
522
+ out = {}
523
+ for res in self.resolution:
524
+ res_log2 = int(np.log2(res))
525
+ x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x)
526
+ out[res_log2] = x
527
+
528
+ return out
529
+
530
+
531
+ @persistence.persistent_class
532
+ class ToStyle(nn.Module):
533
+ def __init__(self, in_channels, out_channels, activation, drop_rate):
534
+ super().__init__()
535
+ self.conv = nn.Sequential(
536
+ Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, down=2),
537
+ Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, down=2),
538
+ Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, down=2),
539
+ )
540
+
541
+ self.pool = nn.AdaptiveAvgPool2d(1)
542
+ self.fc = FullyConnectedLayer(in_features=in_channels,
543
+ out_features=out_channels,
544
+ activation=activation)
545
+ # self.dropout = nn.Dropout(drop_rate)
546
+
547
+ def forward(self, x):
548
+ x = self.conv(x)
549
+ x = self.pool(x)
550
+ x = self.fc(x.flatten(start_dim=1))
551
+ # x = self.dropout(x)
552
+
553
+ return x
554
+
555
+
556
+ @persistence.persistent_class
557
+ class DecBlockFirstV2(nn.Module):
558
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
559
+ super().__init__()
560
+ self.res = res
561
+
562
+ self.conv0 = Conv2dLayer(in_channels=in_channels,
563
+ out_channels=in_channels,
564
+ kernel_size=3,
565
+ activation=activation,
566
+ )
567
+ self.conv1 = StyleConv(in_channels=in_channels,
568
+ out_channels=out_channels,
569
+ style_dim=style_dim,
570
+ resolution=2**res,
571
+ kernel_size=3,
572
+ use_noise=use_noise,
573
+ activation=activation,
574
+ demodulate=demodulate,
575
+ )
576
+ self.toRGB = ToRGB(in_channels=out_channels,
577
+ out_channels=img_channels,
578
+ style_dim=style_dim,
579
+ kernel_size=1,
580
+ demodulate=False,
581
+ )
582
+
583
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
584
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
585
+ x = self.conv0(x)
586
+ x = x + E_features[self.res]
587
+ style = get_style_code(ws[:, 0], gs)
588
+ x = self.conv1(x, style, noise_mode=noise_mode)
589
+ style = get_style_code(ws[:, 1], gs)
590
+ img = self.toRGB(x, style, skip=None)
591
+
592
+ return x, img
593
+
594
+ #----------------------------------------------------------------------------
595
+
596
+ @persistence.persistent_class
597
+ class DecBlock(nn.Module):
598
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): # res = 4, ..., resolution_log2
599
+ super().__init__()
600
+ self.res = res
601
+
602
+ self.conv0 = StyleConv(in_channels=in_channels,
603
+ out_channels=out_channels,
604
+ style_dim=style_dim,
605
+ resolution=2**res,
606
+ kernel_size=3,
607
+ up=2,
608
+ use_noise=use_noise,
609
+ activation=activation,
610
+ demodulate=demodulate,
611
+ )
612
+ self.conv1 = StyleConv(in_channels=out_channels,
613
+ out_channels=out_channels,
614
+ style_dim=style_dim,
615
+ resolution=2**res,
616
+ kernel_size=3,
617
+ use_noise=use_noise,
618
+ activation=activation,
619
+ demodulate=demodulate,
620
+ )
621
+ self.toRGB = ToRGB(in_channels=out_channels,
622
+ out_channels=img_channels,
623
+ style_dim=style_dim,
624
+ kernel_size=1,
625
+ demodulate=False,
626
+ )
627
+
628
+ def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
629
+ style = get_style_code(ws[:, self.res * 2 - 9], gs)
630
+ x = self.conv0(x, style, noise_mode=noise_mode)
631
+ x = x + E_features[self.res]
632
+ style = get_style_code(ws[:, self.res * 2 - 8], gs)
633
+ x = self.conv1(x, style, noise_mode=noise_mode)
634
+ style = get_style_code(ws[:, self.res * 2 - 7], gs)
635
+ img = self.toRGB(x, style, skip=img)
636
+
637
+ return x, img
638
+
639
+
640
+ @persistence.persistent_class
641
+ class Decoder(nn.Module):
642
+ def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels):
643
+ super().__init__()
644
+ self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels)
645
+ for res in range(5, res_log2 + 1):
646
+ setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res),
647
+ DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels))
648
+ self.res_log2 = res_log2
649
+
650
+ def forward(self, x, ws, gs, E_features, noise_mode='random'):
651
+ x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
652
+ for res in range(5, self.res_log2 + 1):
653
+ block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res))
654
+ x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
655
+
656
+ return img
657
+
658
+
659
+ @persistence.persistent_class
660
+ class DecStyleBlock(nn.Module):
661
+ def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
662
+ super().__init__()
663
+ self.res = res
664
+
665
+ self.conv0 = StyleConv(in_channels=in_channels,
666
+ out_channels=out_channels,
667
+ style_dim=style_dim,
668
+ resolution=2**res,
669
+ kernel_size=3,
670
+ up=2,
671
+ use_noise=use_noise,
672
+ activation=activation,
673
+ demodulate=demodulate,
674
+ )
675
+ self.conv1 = StyleConv(in_channels=out_channels,
676
+ out_channels=out_channels,
677
+ style_dim=style_dim,
678
+ resolution=2**res,
679
+ kernel_size=3,
680
+ use_noise=use_noise,
681
+ activation=activation,
682
+ demodulate=demodulate,
683
+ )
684
+ self.toRGB = ToRGB(in_channels=out_channels,
685
+ out_channels=img_channels,
686
+ style_dim=style_dim,
687
+ kernel_size=1,
688
+ demodulate=False,
689
+ )
690
+
691
+ def forward(self, x, img, style, skip, noise_mode='random'):
692
+ x = self.conv0(x, style, noise_mode=noise_mode)
693
+ x = x + skip
694
+ x = self.conv1(x, style, noise_mode=noise_mode)
695
+ img = self.toRGB(x, style, skip=img)
696
+
697
+ return x, img
698
+
699
+
700
+ @persistence.persistent_class
701
+ class FirstStage(nn.Module):
702
+ def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True, activation='lrelu'):
703
+ super().__init__()
704
+ res = 64
705
+
706
+ self.conv_first = Conv2dLayerPartial(in_channels=img_channels+1, out_channels=dim, kernel_size=3, activation=activation)
707
+ self.enc_conv = nn.ModuleList()
708
+ down_time = int(np.log2(img_resolution // res))
709
+ for i in range(down_time): # from input size to 64
710
+ self.enc_conv.append(
711
+ Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)
712
+ )
713
+
714
+ # from 64 -> 16 -> 64
715
+ depths = [2, 3, 4, 3, 2]
716
+ ratios = [1, 1/2, 1/2, 2, 2]
717
+ num_heads = 6
718
+ window_sizes = [8, 16, 16, 16, 8]
719
+ drop_path_rate = 0.1
720
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
721
+
722
+ self.tran = nn.ModuleList()
723
+ for i, depth in enumerate(depths):
724
+ res = int(res * ratios[i])
725
+ if ratios[i] < 1:
726
+ merge = PatchMerging(dim, dim, down=int(1/ratios[i]))
727
+ elif ratios[i] > 1:
728
+ merge = PatchUpsampling(dim, dim, up=ratios[i])
729
+ else:
730
+ merge = None
731
+ self.tran.append(
732
+ BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads,
733
+ window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
734
+ downsample=merge)
735
+ )
736
+
737
+ # global style
738
+ down_conv = []
739
+ for i in range(int(np.log2(16))):
740
+ down_conv.append(Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation))
741
+ down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
742
+ self.down_conv = nn.Sequential(*down_conv)
743
+ self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim*2, activation=activation)
744
+ self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation)
745
+ self.to_square = FullyConnectedLayer(in_features=dim, out_features=16*16, activation=activation)
746
+
747
+ style_dim = dim * 3
748
+ self.dec_conv = nn.ModuleList()
749
+ for i in range(down_time): # from 64 to input size
750
+ res = res * 2
751
+ self.dec_conv.append(DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels))
752
+
753
+ def forward(self, images_in, masks_in, ws, noise_mode='random'):
754
+ x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
755
+
756
+ skips = []
757
+ x, mask = self.conv_first(x, masks_in) # input size
758
+ skips.append(x)
759
+ for i, block in enumerate(self.enc_conv): # input size to 64
760
+ x, mask = block(x, mask)
761
+ if i != len(self.enc_conv) - 1:
762
+ skips.append(x)
763
+
764
+ x_size = x.size()[-2:]
765
+ x = feature2token(x)
766
+ mask = feature2token(mask)
767
+ mid = len(self.tran) // 2
768
+ for i, block in enumerate(self.tran): # 64 to 16
769
+ if i < mid:
770
+ x, x_size, mask = block(x, x_size, mask)
771
+ skips.append(x)
772
+ elif i > mid:
773
+ x, x_size, mask = block(x, x_size, None)
774
+ x = x + skips[mid - i]
775
+ else:
776
+ x, x_size, mask = block(x, x_size, None)
777
+
778
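+ # randomly replace roughly half of the tokens with a style-derived bias; F.dropout is called with training=True, so this also happens at inference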
+ mul_map = torch.ones_like(x) * 0.5
779
+ mul_map = F.dropout(mul_map, training=True)
780
+ ws = self.ws_style(ws[:, -1])
781
+ add_n = self.to_square(ws).unsqueeze(1)
782
+ add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze(-1)
783
+ x = x * mul_map + add_n * (1 - mul_map)
784
+ gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1))
785
+ style = torch.cat([gs, ws], dim=1)
786
+
787
+ x = token2feature(x, x_size).contiguous()
788
+ img = None
789
+ for i, block in enumerate(self.dec_conv):
790
+ x, img = block(x, img, style, skips[len(self.dec_conv)-i-1], noise_mode=noise_mode)
791
+
792
+ # ensemble
793
+ img = img * (1 - masks_in) + images_in * masks_in
794
+
795
+ return img
796
+
797
+
798
+ @persistence.persistent_class
799
+ class SynthesisNet(nn.Module):
800
+ def __init__(self,
801
+ w_dim, # Intermediate latent (W) dimensionality.
802
+ img_resolution, # Output image resolution.
803
+ img_channels = 3, # Number of color channels.
804
+ channel_base = 32768, # Overall multiplier for the number of channels.
805
+ channel_decay = 1.0,
806
+ channel_max = 512, # Maximum number of channels in any layer.
807
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
808
+ drop_rate = 0.5,
809
+ use_noise = True,
810
+ demodulate = True,
811
+ ):
812
+ super().__init__()
813
+ resolution_log2 = int(np.log2(img_resolution))
814
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
815
+
816
+ self.num_layers = resolution_log2 * 2 - 3 * 2
817
+ self.img_resolution = img_resolution
818
+ self.resolution_log2 = resolution_log2
819
+
820
+ # first stage
821
+ self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False, demodulate=demodulate)
822
+
823
+ # second stage
824
+ self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16)
825
+ self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16*16, activation=activation)
826
+ self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate)
827
+ style_dim = w_dim + nf(2) * 2
828
+ self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels)
829
+
830
+ def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False):
831
+ out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
832
+
833
+ # encoder
834
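+ # second-stage input: composite the first-stage output into the holes, then stack [mask - 0.5, composite, masked input] (img_channels * 2 + 1 channels, matching EncFromRGB)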
+ x = images_in * masks_in + out_stg1 * (1 - masks_in)
835
+ x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
836
+ E_features = self.enc(x)
837
+
838
+ fea_16 = E_features[4]
839
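+ # same feature-randomization trick as in FirstStage, applied to the 16x16 encoder features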
+ mul_map = torch.ones_like(fea_16) * 0.5
840
+ mul_map = F.dropout(mul_map, training=True)
841
+ add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
842
+ add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False)
843
+ fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
844
+ E_features[4] = fea_16
845
+
846
+         # style
+         gs = self.to_style(fea_16)
+
+         # decoder
+         img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
+
+         # ensemble
+         img = img * (1 - masks_in) + images_in * masks_in
+
+         if not return_stg1:
+             return img
+         else:
+             return img, out_stg1
+
+
+ @persistence.persistent_class
+ class Generator(nn.Module):
+     def __init__(self,
+                  z_dim,                  # Input latent (Z) dimensionality, 0 = no latent.
+                  c_dim,                  # Conditioning label (C) dimensionality, 0 = no label.
+                  w_dim,                  # Intermediate latent (W) dimensionality.
+                  img_resolution,         # Resolution of generated image.
+                  img_channels,           # Number of input color channels.
+                  synthesis_kwargs = {},  # Arguments for SynthesisNetwork.
+                  mapping_kwargs   = {},  # Arguments for MappingNetwork.
+                  ):
+         super().__init__()
+         self.z_dim = z_dim
+         self.c_dim = c_dim
+         self.w_dim = w_dim
+         self.img_resolution = img_resolution
+         self.img_channels = img_channels
+
+         self.synthesis = SynthesisNet(w_dim=w_dim,
+                                       img_resolution=img_resolution,
+                                       img_channels=img_channels,
+                                       **synthesis_kwargs)
+         self.mapping = MappingNet(z_dim=z_dim,
+                                   c_dim=c_dim,
+                                   w_dim=w_dim,
+                                   num_ws=self.synthesis.num_layers,
+                                   **mapping_kwargs)
+
+     def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
+                 noise_mode='random', return_stg1=False):
+         ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
+                           skip_w_avg_update=skip_w_avg_update)
+
+         if not return_stg1:
+             img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
+             return img
+         else:
+             img, out_stg1 = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode, return_stg1=True)
+             return img, out_stg1
+
+
+ @persistence.persistent_class
+ class Discriminator(torch.nn.Module):
+     def __init__(self,
+                  c_dim,                        # Conditioning label (C) dimensionality.
+                  img_resolution,               # Input resolution.
+                  img_channels,                 # Number of input color channels.
+                  channel_base       = 32768,   # Overall multiplier for the number of channels.
+                  channel_max        = 512,     # Maximum number of channels in any layer.
+                  channel_decay      = 1,
+                  cmap_dim           = None,    # Dimensionality of mapped conditioning label, None = default.
+                  activation         = 'lrelu',
+                  mbstd_group_size   = 4,       # Group size for the minibatch standard deviation layer, None = entire minibatch.
+                  mbstd_num_channels = 1,       # Number of features for the minibatch standard deviation layer, 0 = disable.
+                  ):
+         super().__init__()
+         self.c_dim = c_dim
+         self.img_resolution = img_resolution
+         self.img_channels = img_channels
+
+         resolution_log2 = int(np.log2(img_resolution))
+         assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
+         self.resolution_log2 = resolution_log2
+
+         if cmap_dim is None:
+             cmap_dim = nf(2)
+         if c_dim == 0:
+             cmap_dim = 0
+         self.cmap_dim = cmap_dim
+
+         if c_dim > 0:
+             self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
+
+         Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
+         for res in range(resolution_log2, 2, -1):
+             Dis.append(DisBlock(nf(res), nf(res - 1), activation))
+
+         if mbstd_num_channels > 0:
+             Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
+         Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
+         self.Dis = nn.Sequential(*Dis)
+
+         self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
+         self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
+
+         # for 64x64
+         Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
+         for res in range(resolution_log2, 2, -1):
+             Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
+
+         if mbstd_num_channels > 0:
+             Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
+         Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation))
+         self.Dis_stg1 = nn.Sequential(*Dis_stg1)
+
+         self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation)
+         self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim)
+
+     def forward(self, images_in, masks_in, images_stg1, c):
+         x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
+         x = self.fc1(self.fc0(x.flatten(start_dim=1)))
+
+         x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
+         x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
+
+         if self.c_dim > 0:
+             cmap = self.mapping(None, c)
+
+         if self.cmap_dim > 0:
+             x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+             x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+         return x, x_stg1
+
+
+ if __name__ == '__main__':
+     device = torch.device('cuda:0')
+     batch = 1
+     res = 512
+     G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3).to(device)
+     D = Discriminator(c_dim=0, img_resolution=res, img_channels=3).to(device)
+     img = torch.randn(batch, 3, res, res).to(device)
+     mask = torch.randn(batch, 1, res, res).to(device)
+     z = torch.randn(batch, 512).to(device)
+     G.eval()
+
+     # def count(block):
+     #     return sum(p.numel() for p in block.parameters()) / 10 ** 6
+     # print('Generator', count(G))
+     # print('discriminator', count(D))
+
+     with torch.no_grad():
+         img, img_stg1 = G(img, mask, z, None, return_stg1=True)
+         print('output of G:', img.shape, img_stg1.shape)
+         score, score_stg1 = D(img, mask, img_stg1, None)
+         print('output of D:', score.shape, score_stg1.shape)
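
Editor's note: a minimal inference sketch for the Generator defined above (not part of the commit). It assumes one of the pretrained pickles under models/, that legacy.load_network_pkl exposes the trained generator under the 'G_ema' key (as in the StyleGAN2-ADA codebase this file derives from), and that the image is a [-1, 1] tensor with a binary mask where 1 marks known pixels and 0 marks holes, matching the ensemble step above. All file paths are placeholders.

    import numpy as np
    import PIL.Image
    import torch

    import legacy  # repo-local checkpoint loader (assumed API)

    device = torch.device('cuda:0')
    with open('models/your_checkpoint.pkl', 'rb') as f:            # placeholder path
        G = legacy.load_network_pkl(f)['G_ema'].to(device).eval()  # 'G_ema' key assumed

    # Image in [-1, 1] and binary mask (1 = known, 0 = hole), both at G.img_resolution.
    img = np.array(PIL.Image.open('image.png').convert('RGB'), dtype=np.float32)
    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device) / 127.5 - 1
    mask = np.array(PIL.Image.open('mask.png').convert('L'), dtype=np.float32) / 255.0
    mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device)

    z = torch.randn(1, G.z_dim, device=device)
    with torch.no_grad():
        out = G(img, mask, z, None)                                # completed image in [-1, 1]
    out = ((out[0].permute(1, 2, 0) + 1) * 127.5).clamp(0, 255).to(torch.uint8).cpu().numpy()
    PIL.Image.fromarray(out).save('inpainted.png')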
op.gif ADDED

Git LFS Details

  • SHA256: 2f046c9635d86f7856a4038925b1ecafcccd8113401da4f6883ef4d97a708430
  • Pointer size: 132 Bytes
  • Size of remote file: 6.57 MB
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ easydict
+ future
+ matplotlib
+ numpy
+ opencv-python
+ scikit-image
+ scipy
+ click
+ requests
+ tqdm
+ pyspng
+ ninja
+ imageio-ffmpeg==0.4.3
+ timm
+ psutil
+ scikit-learn
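
Editor's note: these dependencies can typically be installed with pip install -r requirements.txt. PyTorch itself is not pinned here and is assumed to be installed separately; ninja is required so that torch.utils.cpp_extension can build the custom CUDA ops on first use.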
test_sets/CelebA-HQ/images/test1.png ADDED
test_sets/CelebA-HQ/images/test2.png ADDED
test_sets/CelebA-HQ/masks/mask1.png ADDED
test_sets/CelebA-HQ/masks/mask2.png ADDED
test_sets/Places/images/test1.jpg ADDED
test_sets/Places/images/test2.jpg ADDED
test_sets/Places/masks/mask1.png ADDED
test_sets/Places/masks/mask2.png ADDED
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,126 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import os
+ import glob
+ import torch
+ import torch.utils.cpp_extension
+ import importlib
+ import hashlib
+ import shutil
+ from pathlib import Path
+
+ from torch.utils.file_baton import FileBaton
+
+ #----------------------------------------------------------------------------
+ # Global options.
+
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+ #----------------------------------------------------------------------------
+ # Internal helper funcs.
+
+ def _find_compiler_bindir():
+     patterns = [
+         'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+         'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+         'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+         'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+     ]
+     for pattern in patterns:
+         matches = sorted(glob.glob(pattern))
+         if len(matches):
+             return matches[-1]
+     return None
+
+ #----------------------------------------------------------------------------
+ # Main entry point for compiling and loading C++/CUDA plugins.
+
+ _cached_plugins = dict()
+
+ def get_plugin(module_name, sources, **build_kwargs):
+     assert verbosity in ['none', 'brief', 'full']
+
+     # Already cached?
+     if module_name in _cached_plugins:
+         return _cached_plugins[module_name]
+
+     # Print status.
+     if verbosity == 'full':
+         print(f'Setting up PyTorch plugin "{module_name}"...')
+     elif verbosity == 'brief':
+         print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+     try: # pylint: disable=too-many-nested-blocks
+         # Make sure we can find the necessary compiler binaries.
+         if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+             compiler_bindir = _find_compiler_bindir()
+             if compiler_bindir is None:
+                 raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+             os.environ['PATH'] += ';' + compiler_bindir
+
+         # Compile and load.
+         verbose_build = (verbosity == 'full')
+
+         # Incremental build md5sum trickery. Copies all the input source files
+         # into a cached build directory under a combined md5 digest of the input
+         # source files. Copying is done only if the combined digest has changed.
+         # This keeps input file timestamps and filenames the same as in previous
+         # extension builds, allowing for fast incremental rebuilds.
+         #
+         # This optimization is done only in case all the source files reside in
+         # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+         # environment variable is set (we take this as a signal that the user
+         # actually cares about this.)
+         source_dirs_set = set(os.path.dirname(source) for source in sources)
+         if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+             all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+             # Compute a combined hash digest for all source files in the same
+             # custom op directory (usually .cu, .cpp, .py and .h files).
+             hash_md5 = hashlib.md5()
+             for src in all_source_files:
+                 with open(src, 'rb') as f:
+                     hash_md5.update(f.read())
+             build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+             digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+             if not os.path.isdir(digest_build_dir):
+                 os.makedirs(digest_build_dir, exist_ok=True)
+                 baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+                 if baton.try_acquire():
+                     try:
+                         for src in all_source_files:
+                             shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+                     finally:
+                         baton.release()
+                 else:
+                     # Someone else is copying source files under the digest dir,
+                     # wait until done and continue.
+                     baton.wait()
+             digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+             torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+                                            verbose=verbose_build, sources=digest_sources, **build_kwargs)
+         else:
+             torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+         module = importlib.import_module(module_name)
+
+     except:
+         if verbosity == 'brief':
+             print('Failed!')
+         raise
+
+     # Print status and add to cache.
+     if verbosity == 'full':
+         print(f'Done setting up PyTorch plugin "{module_name}".')
+     elif verbosity == 'brief':
+         print('Done.')
+     _cached_plugins[module_name] = module
+     return module
+
+ #----------------------------------------------------------------------------
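
Editor's note: a sketch of how get_plugin is typically invoked by an op wrapper (for example the ones under torch_utils/ops). The plugin name and source files below are hypothetical, not taken from this commit.

    import os
    from torch_utils import custom_ops

    def _init_example_plugin():
        src_dir = os.path.dirname(__file__)                     # directory holding the .cpp/.cu sources (hypothetical)
        sources = [os.path.join(src_dir, 'example_op.cpp'),     # hypothetical source files
                   os.path.join(src_dir, 'example_op.cu')]
        # The first call compiles via torch.utils.cpp_extension.load(); later calls return the cached module.
        return custom_ops.get_plugin('example_op_plugin', sources=sources,
                                     extra_cuda_cflags=['--use_fast_math'])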
torch_utils/misc.py ADDED
@@ -0,0 +1,268 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import re
+ import contextlib
+ import numpy as np
+ import torch
+ import warnings
+ import dnnlib
+
+ #----------------------------------------------------------------------------
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+ # same constant is used multiple times.
+
+ _constant_cache = dict()
+
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+     value = np.asarray(value)
+     if shape is not None:
+         shape = tuple(shape)
+     if dtype is None:
+         dtype = torch.get_default_dtype()
+     if device is None:
+         device = torch.device('cpu')
+     if memory_format is None:
+         memory_format = torch.contiguous_format
+
+     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+     tensor = _constant_cache.get(key, None)
+     if tensor is None:
+         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+         if shape is not None:
+             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+         tensor = tensor.contiguous(memory_format=memory_format)
+         _constant_cache[key] = tensor
+     return tensor
+
+ #----------------------------------------------------------------------------
+ # Replace NaN/Inf with specified numerical values.
+
+ try:
+     nan_to_num = torch.nan_to_num # 1.8.0a0
+ except AttributeError:
+     def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+         assert isinstance(input, torch.Tensor)
+         if posinf is None:
+             posinf = torch.finfo(input.dtype).max
+         if neginf is None:
+             neginf = torch.finfo(input.dtype).min
+         assert nan == 0
+         return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+ #----------------------------------------------------------------------------
+ # Symbolic assert.
+
+ try:
+     symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+ except AttributeError:
+     symbolic_assert = torch.Assert # 1.7.0
+
+ #----------------------------------------------------------------------------
+ # Context manager to suppress known warnings in torch.jit.trace().
+
+ class suppress_tracer_warnings(warnings.catch_warnings):
+     def __enter__(self):
+         super().__enter__()
+         warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+         return self
+
+ #----------------------------------------------------------------------------
+ # Assert that the shape of a tensor matches the given list of integers.
+ # None indicates that the size of a dimension is allowed to vary.
+ # Performs symbolic assertion when used in torch.jit.trace().
+
+ def assert_shape(tensor, ref_shape):
+     if tensor.ndim != len(ref_shape):
+         raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+     for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+         if ref_size is None:
+             pass
+         elif isinstance(ref_size, torch.Tensor):
+             with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                 symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+         elif isinstance(size, torch.Tensor):
+             with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                 symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+         elif size != ref_size:
+             raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+ #----------------------------------------------------------------------------
+ # Function decorator that calls torch.autograd.profiler.record_function().
+
+ def profiled_function(fn):
+     def decorator(*args, **kwargs):
+         with torch.autograd.profiler.record_function(fn.__name__):
+             return fn(*args, **kwargs)
+     decorator.__name__ = fn.__name__
+     return decorator
+
+ #----------------------------------------------------------------------------
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
+ # indefinitely, shuffling items as it goes.
+
+ class InfiniteSampler(torch.utils.data.Sampler):
+     def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+         assert len(dataset) > 0
+         assert num_replicas > 0
+         assert 0 <= rank < num_replicas
+         assert 0 <= window_size <= 1
+         super().__init__(dataset)
+         self.dataset = dataset
+         self.rank = rank
+         self.num_replicas = num_replicas
+         self.shuffle = shuffle
+         self.seed = seed
+         self.window_size = window_size
+
+     def __iter__(self):
+         order = np.arange(len(self.dataset))
+         rnd = None
+         window = 0
+         if self.shuffle:
+             rnd = np.random.RandomState(self.seed)
+             rnd.shuffle(order)
+             window = int(np.rint(order.size * self.window_size))
+
+         idx = 0
+         while True:
+             i = idx % order.size
+             if idx % self.num_replicas == self.rank:
+                 yield order[i]
+                 if window >= 2:
+                     j = (i - rnd.randint(window)) % order.size
+                     order[i], order[j] = order[j], order[i]
+             idx += 1
+
+ #----------------------------------------------------------------------------
+ # Utilities for operating with torch.nn.Module parameters and buffers.
+
+ def params_and_buffers(module):
+     assert isinstance(module, torch.nn.Module)
+     return list(module.parameters()) + list(module.buffers())
+
+ def named_params_and_buffers(module):
+     assert isinstance(module, torch.nn.Module)
+     return list(module.named_parameters()) + list(module.named_buffers())
+
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
+     assert isinstance(src_module, torch.nn.Module)
+     assert isinstance(dst_module, torch.nn.Module)
+     src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
+     for name, tensor in named_params_and_buffers(dst_module):
+         assert (name in src_tensors) or (not require_all)
+         if name in src_tensors:
+             tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+ #----------------------------------------------------------------------------
+ # Context manager for easily enabling/disabling DistributedDataParallel
+ # synchronization.
+
+ @contextlib.contextmanager
+ def ddp_sync(module, sync):
+     assert isinstance(module, torch.nn.Module)
+     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+         yield
+     else:
+         with module.no_sync():
+             yield
+
+ #----------------------------------------------------------------------------
+ # Check DistributedDataParallel consistency across processes.
+
+ def check_ddp_consistency(module, ignore_regex=None):
+     assert isinstance(module, torch.nn.Module)
+     for name, tensor in named_params_and_buffers(module):
+         fullname = type(module).__name__ + '.' + name
+         flag = False
+         if ignore_regex is not None:
+             for regex in ignore_regex:
+                 if re.fullmatch(regex, fullname):
+                     flag = True
+                     break
+         if flag:
+             continue
+         tensor = tensor.detach()
+         other = tensor.clone()
+         torch.distributed.broadcast(tensor=other, src=0)
+         assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
+ #----------------------------------------------------------------------------
+ # Print summary table of module hierarchy.
+
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+     assert isinstance(module, torch.nn.Module)
+     assert not isinstance(module, torch.jit.ScriptModule)
+     assert isinstance(inputs, (tuple, list))
+
+     # Register hooks.
+     entries = []
+     nesting = [0]
+     def pre_hook(_mod, _inputs):
+         nesting[0] += 1
+     def post_hook(mod, _inputs, outputs):
+         nesting[0] -= 1
+         if nesting[0] <= max_nesting:
+             outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+             outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+             entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+     hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+     hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+     # Run module.
+     outputs = module(*inputs)
+     for hook in hooks:
+         hook.remove()
+
+     # Identify unique outputs, parameters, and buffers.
+     tensors_seen = set()
+     for e in entries:
+         e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+         e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+         e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+         tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+     # Filter out redundant entries.
+     if skip_redundant:
+         entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+     # Construct table.
+     rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+     rows += [['---'] * len(rows[0])]
+     param_total = 0
+     buffer_total = 0
+     submodule_names = {mod: name for name, mod in module.named_modules()}
+     for e in entries:
+         name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+         param_size = sum(t.numel() for t in e.unique_params)
+         buffer_size = sum(t.numel() for t in e.unique_buffers)
+         output_shapes = [str(list(t.shape)) for t in e.outputs]
+         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+         rows += [[
+             name + (':0' if len(e.outputs) >= 2 else ''),
+             str(param_size) if param_size else '-',
+             str(buffer_size) if buffer_size else '-',
+             (output_shapes + ['-'])[0],
+             (output_dtypes + ['-'])[0],
+         ]]
+         for idx in range(1, len(e.outputs)):
+             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+         param_total += param_size
+         buffer_total += buffer_size
+     rows += [['---'] * len(rows[0])]
+     rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+     # Print table.
+     widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+     print()
+     for row in rows:
+         print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+     print()
+     return outputs
+
+ #----------------------------------------------------------------------------
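
Editor's note: a small usage sketch for two of the helpers above, InfiniteSampler and assert_shape. The dataset and batch size are placeholders, not taken from this commit.

    import torch
    from torch_utils import misc

    dataset = torch.utils.data.TensorDataset(torch.randn(1000, 3, 64, 64))  # placeholder dataset
    sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
    loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8))

    (batch,) = next(loader)                        # the sampler never raises StopIteration
    misc.assert_shape(batch, [8, 3, None, None])   # None = this dimension may vary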
torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty