	dnnlib
- dnnlib/__init__.py +11 -0
- dnnlib/tflib/__init__.py +20 -0
- dnnlib/tflib/autosummary.py +193 -0
- dnnlib/tflib/custom_ops.py +171 -0
- dnnlib/tflib/network.py +592 -0
- dnnlib/tflib/ops/__init__.py +9 -0
- dnnlib/tflib/ops/fused_bias_act.cu +190 -0
- dnnlib/tflib/ops/fused_bias_act.py +198 -0
- dnnlib/tflib/ops/upfirdn_2d.cu +328 -0
- dnnlib/tflib/ops/upfirdn_2d.py +366 -0
- dnnlib/tflib/optimizer.py +338 -0
- dnnlib/tflib/tfutil.py +254 -0
- dnnlib/util.py +479 -0
    	
dnnlib/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
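Example (not part of the commit): a minimal sketch of the two helpers re-exported here. EasyDict is assumed to behave as a dict with attribute-style access, and make_cache_dir_path is assumed to build a path from the given components inside dnnlib's cache directory; both live in dnnlib/util.py, which is added by this commit but not shown in this excerpt, so treat the exact behavior as an assumption.

import dnnlib

# Attribute-style config dict (assumed behavior of EasyDict from dnnlib/util.py).
cfg = dnnlib.EasyDict(lr=0.002, batch_size=8)
cfg.total_kimg = 25000              # attribute access adds/updates keys
assert cfg['lr'] == cfg.lr          # dict-style and attribute-style access agree

# Hypothetical cache subdirectory name; make_cache_dir_path is assumed to
# return a path under dnnlib's cache directory.
cache_dir = dnnlib.make_cache_dir_path('downloads')
print(cache_dir)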
    	
dnnlib/tflib/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

from . import autosummary
from . import network
from . import optimizer
from . import tfutil
from . import custom_ops

from .tfutil import *
from .network import Network

from .optimizer import Optimizer

from .custom_ops import get_plugin
    	
dnnlib/tflib/autosummary.py
ADDED
@@ -0,0 +1,193 @@
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Helper for adding automatically tracked values to Tensorboard.

Autosummary creates an identity op that internally keeps track of the input
values and automatically shows up in TensorBoard. The reported value
represents an average over input components. The average is accumulated
constantly over time and flushed when save_summaries() is called.

Notes:
- The output tensor must be used as an input for something else in the
  graph. Otherwise, the autosummary op will not get executed, and the average
  value will not get accumulated.
- It is perfectly fine to include autosummaries with the same name in
  several places throughout the graph, even if they are executed concurrently.
- It is ok to also pass in a python scalar or numpy array. In this case, it
  is added to the average immediately.
"""

from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tensorboard import summary as summary_lib
from tensorboard.plugins.custom_scalar import layout_pb2

from . import tfutil
from .tfutil import TfExpression
from .tfutil import TfExpressionEx

# Enable "Custom scalars" tab in TensorBoard for advanced formatting.
# Disabled by default to reduce tfevents file size.
enable_custom_scalars = False

_dtype = tf.float64
_vars = OrderedDict()  # name => [var, ...]
_immediate = OrderedDict()  # name => update_op, update_value
_finalized = False
_merge_op = None


def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators."""
    assert not _finalized
    name_id = name.replace("/", "_")
    v = tf.cast(value_expr, _dtype)

    if v.shape.is_fully_defined():
        size = np.prod(v.shape.as_list())
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    if size == 1:
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
    update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op


def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            condition = tf.convert_to_tensor(condition, name='condition')
            update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            if name not in _immediate:
                with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                    update_value = tf.placeholder(_dtype)
                    update_op = _create_var(name, update_value)
                    _immediate[name] = update_op, update_value
            update_op, update_value = _immediate[name]
            tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru


def finalize_autosummaries() -> None:
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        if enable_custom_scalars:
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Setup layout for custom scalars.
    layout = None
    if enable_custom_scalars:
        cat_dict = OrderedDict()
        for series_name in sorted(_vars.keys()):
            p = series_name.split("/")
            cat = p[0] if len(p) >= 2 else ""
            chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
            if cat not in cat_dict:
                cat_dict[cat] = OrderedDict()
            if chart not in cat_dict[cat]:
                cat_dict[cat][chart] = []
            cat_dict[cat][chart].append(series_name)
        categories = []
        for cat_name, chart_dict in cat_dict.items():
            charts = []
            for chart_name, series_names in chart_dict.items():
                series = []
                for series_name in series_names:
                    series.append(layout_pb2.MarginChartContent.Series(
                        value=series_name,
                        lower="xCustomScalars/" + series_name + "/margin_lo",
                        upper="xCustomScalars/" + series_name + "/margin_hi"))
                margin = layout_pb2.MarginChartContent(series=series)
                charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
            categories.append(layout_pb2.Category(title=cat_name, chart=charts))
        layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout

def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)
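Example (not part of the commit): a minimal sketch of how autosummary() and save_summaries() fit together in a TF1-style training loop. It assumes dnnlib.tflib.tfutil exposes init_tf() and creates a default session, as in the upstream StyleGAN2 code; tfutil.py is added by this commit but not shown in this excerpt. The loss tensor and log directory are stand-ins.

import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib import autosummary

tflib.init_tf()  # assumption: provided by tfutil.py (not shown here)

# Stand-in loss tensor; wrap it in an autosummary so executing it accumulates the average.
loss = tf.reduce_mean(tf.square(tf.random.normal([8, 16])))
loss = autosummary.autosummary('Loss/train', loss)  # identity op with a tracking side-effect

writer = tf.summary.FileWriter('logs')  # TF1-style event file writer
sess = tf.get_default_session()
for step in range(10):
    sess.run(loss)  # the autosummary op only accumulates when the tensor is actually executed
    autosummary.save_summaries(writer, global_step=step)  # finalize/merge on first call, then flush
writer.close()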
    	
dnnlib/tflib/custom_ops.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""TensorFlow custom ops builder.
"""

import os
import re
import uuid
import hashlib
import tempfile
import shutil
import tensorflow as tf
from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module

#----------------------------------------------------------------------------
# Global options.

cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
cuda_cache_version_tag = 'v1'
do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
verbose = True # Print status messages to stdout.

compiler_bindir_search_path = [
    'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
    'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
    'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
]

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    for compiler_path in compiler_bindir_search_path:
        if os.path.isdir(compiler_path):
            return compiler_path
    return None

def _get_compute_cap(device):
    caps_str = device.physical_device_desc
    m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
    major = m.group(1)
    minor = m.group(2)
    return (major, minor)

def _get_cuda_gpu_arch_string():
    gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
    if len(gpus) == 0:
        raise RuntimeError('No GPU devices found')
    (major, minor) = _get_compute_cap(gpus[0])
    return 'sm_%s%s' % (major, minor)

def _run_cmd(cmd):
    with os.popen(cmd) as pipe:
        output = pipe.read()
        status = pipe.close()
    if status is not None:
        raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))

def _prepare_nvcc_cli(opts):
    cmd = 'nvcc ' + opts.strip()
    cmd += ' --disable-warnings'
    cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
    cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
    cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
    cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')

    compiler_bindir = _find_compiler_bindir()
    if compiler_bindir is None:
        # Require that _find_compiler_bindir succeeds on Windows.  Allow
        # nvcc to use whatever is the default on Linux.
        if os.name == 'nt':
            raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
    else:
        cmd += ' --compiler-bindir "%s"' % compiler_bindir
    cmd += ' 2>&1'
    return cmd

#----------------------------------------------------------------------------
# Main entry point.

_plugin_cache = dict()

def get_plugin(cuda_file):
    cuda_file_base = os.path.basename(cuda_file)
    cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)

    # Already in cache?
    if cuda_file in _plugin_cache:
        return _plugin_cache[cuda_file]

    # Setup plugin.
    if verbose:
        print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
    try:
        # Hash CUDA source.
        md5 = hashlib.md5()
        with open(cuda_file, 'rb') as f:
            md5.update(f.read())
        md5.update(b'\n')

        # Hash headers included by the CUDA code by running it through the preprocessor.
        if not do_not_hash_included_headers:
            if verbose:
                print('Preprocessing... ', end='', flush=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
                _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
                with open(tmp_file, 'rb') as f:
                    bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
                    good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
                    for ln in f:
                        if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
                            ln = ln.replace(bad_file_str, good_file_str)
                            md5.update(ln)
                    md5.update(b'\n')

        # Select compiler options.
        compile_opts = ''
        if os.name == 'nt':
            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
        elif os.name == 'posix':
            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
            compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
        else:
            assert False # not Windows or Linux, w00t?
        compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
        compile_opts += ' --use_fast_math'
        nvcc_cmd = _prepare_nvcc_cli(compile_opts)

        # Hash build configuration.
        md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
        md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
        md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')

        # Compile if not already compiled.
        bin_file_ext = '.dll' if os.name == 'nt' else '.so'
        bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
        if not os.path.isfile(bin_file):
            if verbose:
                print('Compiling... ', end='', flush=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
                _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
                os.makedirs(cuda_cache_path, exist_ok=True)
                intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
                shutil.copyfile(tmp_file, intermediate_file)
                os.rename(intermediate_file, bin_file) # atomic

        # Load.
        if verbose:
            print('Loading... ', end='', flush=True)
        plugin = tf.load_op_library(bin_file)

        # Add to cache.
        _plugin_cache[cuda_file] = plugin
        if verbose:
            print('Done.', flush=True)
        return plugin

    except:
        if verbose:
            print('Failed!', flush=True)
        raise

#----------------------------------------------------------------------------
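Example (not part of the commit): a minimal sketch of loading one of the bundled CUDA kernels through get_plugin(). The fused_bias_act.cu file referenced here is added elsewhere in this same commit; on a machine with nvcc, a GPU, and a compatible TF 1.x install, the first call compiles the kernel, caches the shared library under _cudacache, and loads it via tf.load_op_library. The exact path layout is taken from the file list above.

import os
from dnnlib.tflib import custom_ops

# Path to the CUDA source shipped in this commit (dnnlib/tflib/ops/fused_bias_act.cu).
cu_path = os.path.join(os.path.dirname(custom_ops.__file__), 'ops', 'fused_bias_act.cu')

plugin = custom_ops.get_plugin(cu_path)  # compile (if needed), cache, and load the op library
print(plugin)                            # module exposing the custom op(s) defined in the .cu file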
    	
dnnlib/tflib/network.py
ADDED
@@ -0,0 +1,592 @@
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Helper for managing networks."""

import types
import inspect
import re
import uuid
import sys
import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import Any, List, Tuple, Union

from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

_import_handlers = []  # Custom import handlers for dealing with legacy data in pickle import.
_import_module_src = dict()  # Source code for temporary modules created during pickle import.


def import_handler(handler_func):
    """Function decorator for declaring custom import handlers."""
    _import_handlers.append(handler_func)
    return handler_func


class Network:
    """Generic network abstraction.

    Acts as a convenience wrapper for a parameterized network construction
    function, providing several utility methods and convenient access to
    the inputs/outputs/weights.

    Network objects can be safely pickled and unpickled for long-term
    archival purposes. The pickling works reliably as long as the underlying
    network construction function is defined in a standalone Python module
    that has no side effects or application-specific imports.

    Args:
        name: Network name. Used to select TensorFlow name and variable scopes.
        func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
        static_kwargs: Keyword arguments to be passed in to the network construction function.

    Attributes:
        name: User-specified name, defaults to build func name if None.
        scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
        static_kwargs: Arguments passed to the user-supplied build func.
        components: Container for sub-networks. Passed to the build func, and retained between calls.
        num_inputs: Number of input tensors.
        num_outputs: Number of output tensors.
        input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
        output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
        input_shape: Short-hand for input_shapes[0].
        output_shape: Short-hand for output_shapes[0].
        input_templates: Input placeholders in the template graph.
        output_templates: Output tensors in the template graph.
        input_names: Name string for each input.
        output_names: Name string for each output.
        own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
        vars: All variables (local_name => var).
        trainables: All trainable variables (local_name => var).
        var_global_to_local: Mapping from variable global names to local names.
    """

    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()

    def _init_fields(self) -> None:
        self.name = None
        self.scope = None
        self.static_kwargs = util.EasyDict()
        self.components = util.EasyDict()
        self.num_inputs = 0
        self.num_outputs = 0
        self.input_shapes = [[]]
        self.output_shapes = [[]]
        self.input_shape = []
        self.output_shape = []
        self.input_templates = []
        self.output_templates = []
        self.input_names = []
        self.output_names = []
        self.own_vars = OrderedDict()
        self.vars = OrderedDict()
        self.trainables = OrderedDict()
        self.var_global_to_local = OrderedDict()

        self._build_func = None  # User-supplied build function that constructs the network.
        self._build_func_name = None  # Name of the build function.
        self._build_module_src = None  # Full source code of the module containing the build function.
        self._run_cache = dict()  # Cached graph data for Network.run().

    def _init_graph(self) -> None:
        # Collect inputs.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None):  # ignore surrounding control dependencies
                self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [t.shape.as_list() for t in self.input_templates]
        self.output_shapes = [t.shape.as_list() for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables.
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())

    def reset_own_vars(self) -> None:
        """Re-initialize all variables of this network, excluding sub-networks."""
        tfutil.run([var.initializer for var in self.own_vars.values()])

    def reset_vars(self) -> None:
        """Re-initialize all variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self.vars.values()])

    def reset_trainables(self) -> None:
        """Re-initialize all trainable variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self.trainables.values()])

    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
         | 
| 202 | 
            +
                    """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
         | 
| 203 | 
            +
                    assert len(in_expr) == self.num_inputs
         | 
| 204 | 
            +
                    assert not all(expr is None for expr in in_expr)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    # Finalize build func kwargs.
         | 
| 207 | 
            +
                    build_kwargs = dict(self.static_kwargs)
         | 
| 208 | 
            +
                    build_kwargs.update(dynamic_kwargs)
         | 
| 209 | 
            +
                    build_kwargs["is_template_graph"] = False
         | 
| 210 | 
            +
                    build_kwargs["components"] = self.components
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                    # Build TensorFlow graph to evaluate the network.
         | 
| 213 | 
            +
                    with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
         | 
| 214 | 
            +
                        assert tf.get_variable_scope().name == self.scope
         | 
| 215 | 
            +
                        valid_inputs = [expr for expr in in_expr if expr is not None]
         | 
| 216 | 
            +
                        final_inputs = []
         | 
| 217 | 
            +
                        for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
         | 
| 218 | 
            +
                            if expr is not None:
         | 
| 219 | 
            +
                                expr = tf.identity(expr, name=name)
         | 
| 220 | 
            +
                            else:
         | 
| 221 | 
            +
                                expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
         | 
| 222 | 
            +
                            final_inputs.append(expr)
         | 
| 223 | 
            +
                        out_expr = self._build_func(*final_inputs, **build_kwargs)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    # Propagate input shapes back to the user-specified expressions.
         | 
| 226 | 
            +
                    for expr, final in zip(in_expr, final_inputs):
         | 
| 227 | 
            +
                        if isinstance(expr, tf.Tensor):
         | 
| 228 | 
            +
                            expr.set_shape(final.shape)
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    # Express outputs in the desired format.
         | 
| 231 | 
            +
                    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
         | 
| 232 | 
            +
                    if return_as_list:
         | 
| 233 | 
            +
                        out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
         | 
| 234 | 
            +
                    return out_expr
         | 
| 235 | 
            +
             | 
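# Usage sketch (illustration only): wiring a network into a larger graph with
# get_output_for(). The placeholder shapes, the network name 'G', and the
# 'is_training' kwarg of its build function are assumptions for the example.
import tensorflow as tf

latents_in = tf.placeholder(tf.float32, [None, 512], name="latents_in")
labels_in = tf.placeholder(tf.float32, [None, 0], name="labels_in")
fake_images_out = G.get_output_for(latents_in, labels_in, is_training=True)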
| 236 | 
            +
                def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
         | 
| 237 | 
            +
                    """Get the local name of a given variable, without any surrounding name scopes."""
         | 
| 238 | 
            +
                    assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
         | 
| 239 | 
            +
                    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
         | 
| 240 | 
            +
                    return self.var_global_to_local[global_name]
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
         | 
| 243 | 
            +
                    """Find variable by local or global name."""
         | 
| 244 | 
            +
                    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
         | 
| 245 | 
            +
                    return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
         | 
| 248 | 
            +
                    """Get the value of a given variable as NumPy array.
         | 
| 249 | 
            +
                    Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
         | 
| 250 | 
            +
                    return self.find_var(var_or_local_name).eval()
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
         | 
| 253 | 
            +
                    """Set the value of a given variable based on the given NumPy array.
         | 
| 254 | 
            +
                    Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
         | 
| 255 | 
            +
                    tfutil.set_vars({self.find_var(var_or_local_name): new_value})
         | 
| 256 | 
            +
             | 
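# Usage sketch (illustration only): batched variable access, as the notes above
# recommend over per-variable get_var()/set_var(). 'net' is an assumed Network instance.
names = list(net.trainables.keys())
values = tfutil.run([net.vars[name] for name in names])                      # one session call
tfutil.set_vars({net.vars[name]: val for name, val in zip(names, values)})   # one session call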
| 257 | 
            +
                def __getstate__(self) -> dict:
         | 
| 258 | 
            +
                    """Pickle export."""
         | 
| 259 | 
            +
                    state = dict()
         | 
| 260 | 
            +
                    state["version"]            = 4
         | 
| 261 | 
            +
                    state["name"]               = self.name
         | 
| 262 | 
            +
                    state["static_kwargs"]      = dict(self.static_kwargs)
         | 
| 263 | 
            +
                    state["components"]         = dict(self.components)
         | 
| 264 | 
            +
                    state["build_module_src"]   = self._build_module_src
         | 
| 265 | 
            +
                    state["build_func_name"]    = self._build_func_name
         | 
| 266 | 
            +
                    state["variables"]          = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
         | 
| 267 | 
            +
                    return state
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def __setstate__(self, state: dict) -> None:
         | 
| 270 | 
            +
                    """Pickle import."""
         | 
| 271 | 
            +
                    # pylint: disable=attribute-defined-outside-init
         | 
| 272 | 
            +
                    tfutil.assert_tf_initialized()
         | 
| 273 | 
            +
                    self._init_fields()
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    # Execute custom import handlers.
         | 
| 276 | 
            +
                    for handler in _import_handlers:
         | 
| 277 | 
            +
                        state = handler(state)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    # Set basic fields.
         | 
| 280 | 
            +
                    assert state["version"] in [2, 3, 4]
         | 
| 281 | 
            +
                    self.name = state["name"]
         | 
| 282 | 
            +
                    self.static_kwargs = util.EasyDict(state["static_kwargs"])
         | 
| 283 | 
            +
                    self.components = util.EasyDict(state.get("components", {}))
         | 
| 284 | 
            +
                    self._build_module_src = state["build_module_src"]
         | 
| 285 | 
            +
                    self._build_func_name = state["build_func_name"]
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    # Create temporary module from the imported source code.
         | 
| 288 | 
            +
                    module_name = "_tflib_network_import_" + uuid.uuid4().hex
         | 
| 289 | 
            +
                    module = types.ModuleType(module_name)
         | 
| 290 | 
            +
                    sys.modules[module_name] = module
         | 
| 291 | 
            +
                    _import_module_src[module] = self._build_module_src
         | 
| 292 | 
            +
                    exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    # Locate network build function in the temporary module.
         | 
| 295 | 
            +
                    self._build_func = util.get_obj_from_module(module, self._build_func_name)
         | 
| 296 | 
            +
                    assert callable(self._build_func)
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    # Init TensorFlow graph.
         | 
| 299 | 
            +
                    self._init_graph()
         | 
| 300 | 
            +
                    self.reset_own_vars()
         | 
| 301 | 
            +
                    tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
         | 
| 302 | 
            +
             | 
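# Usage sketch (illustration only): because __getstate__/__setstate__ carry the
# build-function source along with the variable values, a Network round-trips
# through pickle without importing its original module. Assumes the TensorFlow
# session has already been initialized (e.g. via tflib.init_tf()).
import pickle

with open("network-snapshot.pkl", "wb") as f:      # file name is an example
    pickle.dump(net, f)
with open("network-snapshot.pkl", "rb") as f:
    net_copy = pickle.load(f)                       # rebuilds the template graph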
| 303 | 
            +
                def clone(self, name: str = None, **new_static_kwargs) -> "Network":
         | 
| 304 | 
            +
                    """Create a clone of this network with its own copy of the variables."""
         | 
| 305 | 
            +
                    # pylint: disable=protected-access
         | 
| 306 | 
            +
                    net = object.__new__(Network)
         | 
| 307 | 
            +
                    net._init_fields()
         | 
| 308 | 
            +
                    net.name = name if name is not None else self.name
         | 
| 309 | 
            +
                    net.static_kwargs = util.EasyDict(self.static_kwargs)
         | 
| 310 | 
            +
                    net.static_kwargs.update(new_static_kwargs)
         | 
| 311 | 
            +
                    net._build_module_src = self._build_module_src
         | 
| 312 | 
            +
                    net._build_func_name = self._build_func_name
         | 
| 313 | 
            +
                    net._build_func = self._build_func
         | 
| 314 | 
            +
                    net._init_graph()
         | 
| 315 | 
            +
                    net.copy_vars_from(self)
         | 
| 316 | 
            +
                    return net
         | 
| 317 | 
            +
             | 
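# Usage sketch (illustration only): clone() is the usual way to snapshot a network
# into a copy with its own variables. 'G' is an assumed Network instance.
Gs = G.clone("Gs")   # new variables under the scope "Gs", values copied from G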
| 318 | 
            +
                def copy_own_vars_from(self, src_net: "Network") -> None:
         | 
| 319 | 
            +
                    """Copy the values of all variables from the given network, excluding sub-networks."""
         | 
| 320 | 
            +
                    names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
         | 
| 321 | 
            +
                    tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                def copy_vars_from(self, src_net: "Network") -> None:
         | 
| 324 | 
            +
                    """Copy the values of all variables from the given network, including sub-networks."""
         | 
| 325 | 
            +
                    names = [name for name in self.vars.keys() if name in src_net.vars]
         | 
| 326 | 
            +
                    tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                def copy_trainables_from(self, src_net: "Network") -> None:
         | 
| 329 | 
            +
                    """Copy the values of all trainable variables from the given network, including sub-networks."""
         | 
| 330 | 
            +
                    names = [name for name in self.trainables.keys() if name in src_net.trainables]
         | 
| 331 | 
            +
                    tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
         | 
| 334 | 
            +
                    """Create new network with the given parameters, and copy all variables from this network."""
         | 
| 335 | 
            +
                    if new_name is None:
         | 
| 336 | 
            +
                        new_name = self.name
         | 
| 337 | 
            +
                    static_kwargs = dict(self.static_kwargs)
         | 
| 338 | 
            +
                    static_kwargs.update(new_static_kwargs)
         | 
| 339 | 
            +
                    net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
         | 
| 340 | 
            +
                    net.copy_vars_from(self)
         | 
| 341 | 
            +
                    return net
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
         | 
| 344 | 
            +
                    """Construct a TensorFlow op that updates the variables of this network
         | 
| 345 | 
            +
                    to be slightly closer to those of the given network."""
         | 
| 346 | 
            +
                    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
         | 
| 347 | 
            +
                        ops = []
         | 
| 348 | 
            +
                        for name, var in self.vars.items():
         | 
| 349 | 
            +
                            if name in src_net.vars:
         | 
| 350 | 
            +
                                cur_beta = beta if name in self.trainables else beta_nontrainable
         | 
| 351 | 
            +
                                new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
         | 
| 352 | 
            +
                                ops.append(var.assign(new_value))
         | 
| 353 | 
            +
                        return tf.group(*ops)
         | 
| 354 | 
            +
             | 
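# Usage sketch (illustration only): the typical exponential-moving-average setup.
# Clone once, build the update op once, then run it after every training step.
# 'G', 'num_steps', and the training step itself are assumptions for the example.
Gs = G.clone("Gs")
Gs_update_op = Gs.setup_as_moving_average_of(G, beta=0.99)
for _ in range(num_steps):
    # ... run one optimization step on G here ...
    tfutil.run(Gs_update_op)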
| 355 | 
            +
                def run(self,
         | 
| 356 | 
            +
                        *in_arrays: Tuple[Union[np.ndarray, None], ...],
         | 
| 357 | 
            +
                        input_transform: dict = None,
         | 
| 358 | 
            +
                        output_transform: dict = None,
         | 
| 359 | 
            +
                        return_as_list: bool = False,
         | 
| 360 | 
            +
                        print_progress: bool = False,
         | 
| 361 | 
            +
                        minibatch_size: int = None,
         | 
| 362 | 
            +
                        num_gpus: int = 1,
         | 
| 363 | 
            +
                        assume_frozen: bool = False,
         | 
| 364 | 
            +
                        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
         | 
| 365 | 
            +
                    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    Args:
         | 
| 368 | 
            +
                        input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
         | 
| 369 | 
            +
                                            The dict must contain a 'func' field that points to a top-level function. The function is called with the input
         | 
| 370 | 
            +
                                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
         | 
| 371 | 
            +
                        output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
         | 
| 372 | 
            +
                                            The dict must contain a 'func' field that points to a top-level function. The function is called with the output
         | 
| 373 | 
            +
                                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
         | 
| 374 | 
            +
                        return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
         | 
| 375 | 
            +
                        print_progress:     Print progress to the console? Useful for very large input arrays.
         | 
| 376 | 
            +
                        minibatch_size:     Maximum minibatch size to use, None = disable batching.
         | 
| 377 | 
            +
                        num_gpus:           Number of GPUs to use.
         | 
| 378 | 
            +
                        assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
         | 
| 379 | 
            +
                        dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
         | 
| 380 | 
            +
                    """
         | 
| 381 | 
            +
                    assert len(in_arrays) == self.num_inputs
         | 
| 382 | 
            +
                    assert not all(arr is None for arr in in_arrays)
         | 
| 383 | 
            +
                    assert input_transform is None or util.is_top_level_function(input_transform["func"])
         | 
| 384 | 
            +
                    assert output_transform is None or util.is_top_level_function(output_transform["func"])
         | 
| 385 | 
            +
                    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
         | 
| 386 | 
            +
                    num_items = in_arrays[0].shape[0]
         | 
| 387 | 
            +
                    if minibatch_size is None:
         | 
| 388 | 
            +
                        minibatch_size = num_items
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    # Construct unique hash key from all arguments that affect the TensorFlow graph.
         | 
| 391 | 
            +
                    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
         | 
| 392 | 
            +
                    def unwind_key(obj):
         | 
| 393 | 
            +
                        if isinstance(obj, dict):
         | 
| 394 | 
            +
                            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
         | 
| 395 | 
            +
                        if callable(obj):
         | 
| 396 | 
            +
                            return util.get_top_level_function_name(obj)
         | 
| 397 | 
            +
                        return obj
         | 
| 398 | 
            +
                    key = repr(unwind_key(key))
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    # Build graph.
         | 
| 401 | 
            +
                    if key not in self._run_cache:
         | 
| 402 | 
            +
                        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
         | 
| 403 | 
            +
                            with tf.device("/cpu:0"):
         | 
| 404 | 
            +
                                in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
         | 
| 405 | 
            +
                                in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                            out_split = []
         | 
| 408 | 
            +
                            for gpu in range(num_gpus):
         | 
| 409 | 
            +
                                with tf.device("/gpu:%d" % gpu):
         | 
| 410 | 
            +
                                    net_gpu = self.clone() if assume_frozen else self
         | 
| 411 | 
            +
                                    in_gpu = in_split[gpu]
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                                    if input_transform is not None:
         | 
| 414 | 
            +
                                        in_kwargs = dict(input_transform)
         | 
| 415 | 
            +
                                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
         | 
| 416 | 
            +
                                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                                    assert len(in_gpu) == self.num_inputs
         | 
| 419 | 
            +
                                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                                    if output_transform is not None:
         | 
| 422 | 
            +
                                        out_kwargs = dict(output_transform)
         | 
| 423 | 
            +
                                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
         | 
| 424 | 
            +
                                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                                    assert len(out_gpu) == self.num_outputs
         | 
| 427 | 
            +
                                    out_split.append(out_gpu)
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                            with tf.device("/cpu:0"):
         | 
| 430 | 
            +
                                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
         | 
| 431 | 
            +
                                self._run_cache[key] = in_expr, out_expr
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    # Run minibatches.
         | 
| 434 | 
            +
                    in_expr, out_expr = self._run_cache[key]
         | 
| 435 | 
            +
                    out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    for mb_begin in range(0, num_items, minibatch_size):
         | 
| 438 | 
            +
                        if print_progress:
         | 
| 439 | 
            +
                            print("\r%d / %d" % (mb_begin, num_items), end="")
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                        mb_end = min(mb_begin + minibatch_size, num_items)
         | 
| 442 | 
            +
                        mb_num = mb_end - mb_begin
         | 
| 443 | 
            +
                        mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
         | 
| 444 | 
            +
                        mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                        for dst, src in zip(out_arrays, mb_out):
         | 
| 447 | 
            +
                            dst[mb_begin: mb_end] = src
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    # Done.
         | 
| 450 | 
            +
                    if print_progress:
         | 
| 451 | 
            +
                        print("\r%d / %d" % (num_items, num_items))
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                    if not return_as_list:
         | 
| 454 | 
            +
                        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
         | 
| 455 | 
            +
                    return out_arrays
         | 
| 456 | 
            +
             | 
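# Usage sketch (illustration only): evaluating a network on NumPy inputs with run().
# The shapes, the network name 'Gs', and 'import dnnlib.tflib as tflib' are assumptions;
# the output_transform shown is the form recommended by the deprecation notice below.
import numpy as np

latents = np.random.randn(16, 512).astype(np.float32)
labels = np.zeros([16, 0], dtype=np.float32)
images = Gs.run(latents, labels,
                output_transform=dict(func=tflib.convert_images_to_uint8),
                minibatch_size=8)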
| 457 | 
            +
                def list_ops(self) -> List[TfExpression]:
         | 
| 458 | 
            +
                    include_prefix = self.scope + "/"
         | 
| 459 | 
            +
                    exclude_prefix = include_prefix + "_"
         | 
| 460 | 
            +
                    ops = tf.get_default_graph().get_operations()
         | 
| 461 | 
            +
                    ops = [op for op in ops if op.name.startswith(include_prefix)]
         | 
| 462 | 
            +
                    ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
         | 
| 463 | 
            +
                    return ops
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
         | 
| 466 | 
            +
                    """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
         | 
| 467 | 
            +
                    individual layers of the network. Mainly intended to be used for reporting."""
         | 
| 468 | 
            +
                    layers = []
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    def recurse(scope, parent_ops, parent_vars, level):
         | 
| 471 | 
            +
                        # Ignore specific patterns.
         | 
| 472 | 
            +
                        if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
         | 
| 473 | 
            +
                            return
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                        # Filter ops and vars by scope.
         | 
| 476 | 
            +
                        global_prefix = scope + "/"
         | 
| 477 | 
            +
                        local_prefix = global_prefix[len(self.scope) + 1:]
         | 
| 478 | 
            +
                        cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
         | 
| 479 | 
            +
                        cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
         | 
| 480 | 
            +
                        if not cur_ops and not cur_vars:
         | 
| 481 | 
            +
                            return
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                        # Filter out all ops related to variables.
         | 
| 484 | 
            +
                        for var in [op for op in cur_ops if op.type.startswith("Variable")]:
         | 
| 485 | 
            +
                            var_prefix = var.name + "/"
         | 
| 486 | 
            +
                            cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                        # Scope does not contain ops as immediate children => recurse deeper.
         | 
| 489 | 
            +
                        contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
         | 
| 490 | 
            +
                        if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
         | 
| 491 | 
            +
                            visited = set()
         | 
| 492 | 
            +
                            for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
         | 
| 493 | 
            +
                                token = rel_name.split("/")[0]
         | 
| 494 | 
            +
                                if token not in visited:
         | 
| 495 | 
            +
                                    recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
         | 
| 496 | 
            +
                                    visited.add(token)
         | 
| 497 | 
            +
                            return
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                        # Report layer.
         | 
| 500 | 
            +
                        layer_name = scope[len(self.scope) + 1:]
         | 
| 501 | 
            +
                        layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
         | 
| 502 | 
            +
                        layer_trainables = [var for _name, var in cur_vars if var.trainable]
         | 
| 503 | 
            +
                        layers.append((layer_name, layer_output, layer_trainables))
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                    recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
         | 
| 506 | 
            +
                    return layers
         | 
| 507 | 
            +
             | 
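# Usage sketch (illustration only): consuming list_layers() for quick reporting,
# e.g. a per-layer parameter count. 'net' is an assumed Network instance.
for layer_name, layer_output, layer_trainables in net.list_layers():
    n_params = sum(int(np.prod(v.shape.as_list())) for v in layer_trainables)
    print(layer_name, layer_output.shape, n_params)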
| 508 | 
            +
                def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
         | 
| 509 | 
            +
                    """Print a summary table of the network structure."""
         | 
| 510 | 
            +
                    rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
         | 
| 511 | 
            +
                    rows += [["---"] * 4]
         | 
| 512 | 
            +
                    total_params = 0
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                    for layer_name, layer_output, layer_trainables in self.list_layers():
         | 
| 515 | 
            +
                        num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
         | 
| 516 | 
            +
                        weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
         | 
| 517 | 
            +
                        weights.sort(key=lambda x: len(x.name))
         | 
| 518 | 
            +
                        if len(weights) == 0 and len(layer_trainables) == 1:
         | 
| 519 | 
            +
                            weights = layer_trainables
         | 
| 520 | 
            +
                        total_params += num_params
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                        if not hide_layers_with_no_params or num_params != 0:
         | 
| 523 | 
            +
                            num_params_str = str(num_params) if num_params > 0 else "-"
         | 
| 524 | 
            +
                            output_shape_str = str(layer_output.shape)
         | 
| 525 | 
            +
                            weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
         | 
| 526 | 
            +
                            rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                    rows += [["---"] * 4]
         | 
| 529 | 
            +
                    rows += [["Total", str(total_params), "", ""]]
         | 
| 530 | 
            +
             | 
| 531 | 
            +
                    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
         | 
| 532 | 
            +
                    print()
         | 
| 533 | 
            +
                    for row in rows:
         | 
| 534 | 
            +
                        print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
         | 
| 535 | 
            +
                    print()
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                def setup_weight_histograms(self, title: str = None) -> None:
         | 
| 538 | 
            +
                    """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
         | 
| 539 | 
            +
                    if title is None:
         | 
| 540 | 
            +
                        title = self.name
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
         | 
| 543 | 
            +
                        for local_name, var in self.trainables.items():
         | 
| 544 | 
            +
                            if "/" in local_name:
         | 
| 545 | 
            +
                                p = local_name.split("/")
         | 
| 546 | 
            +
                                name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
         | 
| 547 | 
            +
                            else:
         | 
| 548 | 
            +
                                name = title + "_toplevel/" + local_name
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                            tf.summary.histogram(name, var)
         | 
| 551 | 
            +
             | 
| 552 | 
            +
            #----------------------------------------------------------------------------
         | 
| 553 | 
            +
            # Backwards-compatible emulation of legacy output transformation in Network.run().
         | 
| 554 | 
            +
             | 
| 555 | 
            +
            _print_legacy_warning = True
         | 
| 556 | 
            +
             | 
| 557 | 
            +
            def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
         | 
| 558 | 
            +
                global _print_legacy_warning
         | 
| 559 | 
            +
                legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
         | 
| 560 | 
            +
                if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
         | 
| 561 | 
            +
                    return output_transform, dynamic_kwargs
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                if _print_legacy_warning:
         | 
| 564 | 
            +
                    _print_legacy_warning = False
         | 
| 565 | 
            +
                    print()
         | 
| 566 | 
            +
                    print("WARNING: Old-style output transformations in Network.run() are deprecated.")
         | 
| 567 | 
            +
                    print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
         | 
| 568 | 
            +
                    print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
         | 
| 569 | 
            +
                    print()
         | 
| 570 | 
            +
                assert output_transform is None
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                new_kwargs = dict(dynamic_kwargs)
         | 
| 573 | 
            +
                new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
         | 
| 574 | 
            +
                new_transform["func"] = _legacy_output_transform_func
         | 
| 575 | 
            +
                return new_transform, new_kwargs
         | 
| 576 | 
            +
             | 
| 577 | 
            +
            def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
         | 
| 578 | 
            +
                if out_mul != 1.0:
         | 
| 579 | 
            +
                    expr = [x * out_mul for x in expr]
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                if out_add != 0.0:
         | 
| 582 | 
            +
                    expr = [x + out_add for x in expr]
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                if out_shrink > 1:
         | 
| 585 | 
            +
                    ksize = [1, 1, out_shrink, out_shrink]
         | 
| 586 | 
            +
                    expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                if out_dtype is not None:
         | 
| 589 | 
            +
                    if tf.as_dtype(out_dtype).is_integer:
         | 
| 590 | 
            +
                        expr = [tf.round(x) for x in expr]
         | 
| 591 | 
            +
                    expr = [tf.saturate_cast(x, out_dtype) for x in expr]
         | 
| 592 | 
            +
                return expr
         | 
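# For reference (illustration only): the equivalence the legacy shim constructs.
# A call such as
#     Gs.run(latents, labels, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
# is rewritten by _handle_legacy_output_transforms() into
#     Gs.run(latents, labels, output_transform=dict(func=_legacy_output_transform_func,
#                                                   out_mul=127.5, out_add=127.5, out_dtype=np.uint8))
# with the preferred modern form being output_transform=dict(func=tflib.convert_images_to_uint8).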
    	
        dnnlib/tflib/ops/__init__.py
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
| 1 | 
            +
            # Copyright (c) SenseTime Research. All rights reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # This work is made available under the Nvidia Source Code License-NC.
         | 
| 6 | 
            +
            # To view a copy of this license, visit
         | 
| 7 | 
            +
            # https://nvlabs.github.io/stylegan2/license.html
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # empty
         | 
    	
        dnnlib/tflib/ops/fused_bias_act.cu
    ADDED
    
    | @@ -0,0 +1,190 @@ | |
| 1 | 
            +
            // Copyright (c) SenseTime Research. All rights reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 4 | 
            +
            //
         | 
| 5 | 
            +
            // This work is made available under the Nvidia Source Code License-NC.
         | 
| 6 | 
            +
            // To view a copy of this license, visit
         | 
| 7 | 
            +
            // https://nvlabs.github.io/stylegan2/license.html
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            #define EIGEN_USE_GPU
         | 
| 10 | 
            +
            #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
         | 
| 11 | 
            +
            #include "tensorflow/core/framework/op.h"
         | 
| 12 | 
            +
            #include "tensorflow/core/framework/op_kernel.h"
         | 
| 13 | 
            +
            #include "tensorflow/core/framework/shape_inference.h"
         | 
| 14 | 
            +
            #include <stdio.h>
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            using namespace tensorflow;
         | 
| 17 | 
            +
            using namespace tensorflow::shape_inference;
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            //------------------------------------------------------------------------
         | 
| 22 | 
            +
            // CUDA kernel.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            template <class T>
         | 
| 25 | 
            +
            struct FusedBiasActKernelParams
         | 
| 26 | 
            +
            {
         | 
| 27 | 
            +
                const T*    x;      // [sizeX]
         | 
| 28 | 
            +
                const T*    b;      // [sizeB] or NULL
         | 
| 29 | 
            +
                const T*    ref;    // [sizeX] or NULL
         | 
| 30 | 
            +
                T*          y;      // [sizeX]
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                int         grad;
         | 
| 33 | 
            +
                int         axis;
         | 
| 34 | 
            +
                int         act;
         | 
| 35 | 
            +
                float       alpha;
         | 
| 36 | 
            +
                float       gain;
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                int         sizeX;
         | 
| 39 | 
            +
                int         sizeB;
         | 
| 40 | 
            +
                int         stepB;
         | 
| 41 | 
            +
                int         loopX;
         | 
| 42 | 
            +
            };
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            template <class T>
         | 
| 45 | 
            +
            static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
         | 
| 46 | 
            +
            {
         | 
| 47 | 
            +
                const float expRange        = 80.0f;
         | 
| 48 | 
            +
                const float halfExpRange    = 40.0f;
         | 
| 49 | 
            +
                const float seluScale       = 1.0507009873554804934193349852946f;
         | 
| 50 | 
            +
                const float seluAlpha       = 1.6732632423543772848170429916717f;
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                // Loop over elements.
         | 
| 53 | 
            +
                int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
         | 
| 54 | 
            +
                for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
         | 
| 55 | 
            +
                {
         | 
| 56 | 
            +
                    // Load and apply bias.
         | 
| 57 | 
            +
                    float x = (float)p.x[xi];
         | 
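                // p.stepB is the product of the dimensions after 'axis', so
                // (xi / p.stepB) % p.sizeB is the coordinate along the bias axis.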
| 58 | 
            +
                    if (p.b)
         | 
| 59 | 
            +
                        x += (float)p.b[(xi / p.stepB) % p.sizeB];
         | 
| 60 | 
            +
                    float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
         | 
| 61 | 
            +
                if (p.gain != 0.0f && p.act != 9)
         | 
| 62 | 
            +
                        ref /= p.gain;
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    // Evaluate activation func.
         | 
| 65 | 
            +
                    float y;
         | 
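                // Dispatch on (p.act * 10 + p.grad): the tens digit selects the nonlinearity
                // (1=linear, 2=relu, 3=lrelu, 4=tanh, 5=sigmoid, 6=elu, 7=selu, 8=softplus, 9=swish)
                // and the ones digit selects forward (0), first gradient (1), or second gradient (2);
                // the gradient branches reuse the saved reference value 'ref'.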
| 66 | 
            +
                    switch (p.act * 10 + p.grad)
         | 
| 67 | 
            +
                    {
         | 
| 68 | 
            +
                        // linear
         | 
| 69 | 
            +
                        default:
         | 
| 70 | 
            +
                        case 10: y = x; break;
         | 
| 71 | 
            +
                        case 11: y = x; break;
         | 
| 72 | 
            +
                        case 12: y = 0.0f; break;
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                        // relu
         | 
| 75 | 
            +
                        case 20: y = (x > 0.0f) ? x : 0.0f; break;
         | 
| 76 | 
            +
                        case 21: y = (ref > 0.0f) ? x : 0.0f; break;
         | 
| 77 | 
            +
                        case 22: y = 0.0f; break;
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                        // lrelu
         | 
| 80 | 
            +
                        case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
         | 
| 81 | 
            +
                        case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
         | 
| 82 | 
            +
                        case 32: y = 0.0f; break;
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        // tanh
         | 
| 85 | 
            +
                        case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
         | 
| 86 | 
            +
                        case 41: y = x * (1.0f - ref * ref); break;
         | 
| 87 | 
            +
                        case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                        // sigmoid
         | 
| 90 | 
            +
                        case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
         | 
| 91 | 
            +
                        case 51: y = x * ref * (1.0f - ref); break;
         | 
| 92 | 
            +
                        case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                        // elu
         | 
| 95 | 
            +
                        case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
         | 
| 96 | 
            +
                        case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
         | 
| 97 | 
            +
                        case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                        // selu
         | 
| 100 | 
            +
                        case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
         | 
| 101 | 
            +
                        case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
         | 
| 102 | 
            +
                        case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                        // softplus
         | 
| 105 | 
            +
                        case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
         | 
| 106 | 
            +
                        case 81: y = x * (1.0f - expf(-ref)); break;
         | 
| 107 | 
            +
                        case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                        // swish
         | 
| 110 | 
            +
                        case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
         | 
| 111 | 
            +
                        case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
         | 
| 112 | 
            +
                        case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
         | 
| 113 | 
            +
                    }
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    // Apply gain and store.
         | 
| 116 | 
            +
                    p.y[xi] = (T)(y * p.gain);
         | 
| 117 | 
            +
                }
         | 
| 118 | 
            +
            }
         | 
| 119 | 
            +
             | 
| 120 | 
            +
            //------------------------------------------------------------------------
         | 
| 121 | 
            +
            // TensorFlow op.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
            template <class T>
         | 
| 124 | 
            +
            struct FusedBiasActOp : public OpKernel
         | 
| 125 | 
            +
            {
         | 
| 126 | 
            +
                FusedBiasActKernelParams<T> m_attribs;
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
         | 
| 129 | 
            +
                {
         | 
| 130 | 
            +
                    memset(&m_attribs, 0, sizeof(m_attribs));
         | 
| 131 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
         | 
| 132 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
         | 
| 133 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
         | 
| 134 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
         | 
| 135 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
         | 
| 136 | 
            +
                    OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
         | 
| 137 | 
            +
                    OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
         | 
| 138 | 
            +
                    OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
         | 
| 139 | 
            +
                }
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                void Compute(OpKernelContext* ctx)
         | 
| 142 | 
            +
                {
         | 
| 143 | 
            +
                    FusedBiasActKernelParams<T> p = m_attribs;
         | 
| 144 | 
            +
                    cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    const Tensor& x     = ctx->input(0); // [...]
         | 
| 147 | 
            +
                    const Tensor& b     = ctx->input(1); // [sizeB] or [0]
         | 
| 148 | 
            +
                    const Tensor& ref   = ctx->input(2); // x.shape or [0]
         | 
| 149 | 
            +
                    p.x = x.flat<T>().data();
         | 
| 150 | 
            +
                    p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
         | 
| 151 | 
            +
                    p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
         | 
| 152 | 
            +
                    OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
         | 
| 153 | 
            +
                    OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
         | 
| 154 | 
            +
                    OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
         | 
| 155 | 
            +
                    OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
         | 
| 156 | 
            +
                    OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    p.sizeX = (int)x.NumElements();
         | 
| 159 | 
            +
                    p.sizeB = (int)b.NumElements();
         | 
| 160 | 
            +
                    p.stepB = 1;
         | 
| 161 | 
            +
                    for (int i = m_attribs.axis + 1; i < x.dims(); i++)
         | 
| 162 | 
            +
                        p.stepB *= (int)x.dim_size(i);
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    Tensor* y = NULL; // x.shape
         | 
| 165 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
         | 
| 166 | 
            +
                    p.y = y->flat<T>().data();
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    p.loopX = 4;
         | 
| 169 | 
            +
                    int blockSize = 4 * 32;
         | 
| 170 | 
            +
                    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
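        // Launch-size sketch (illustrative numbers, not from the original source): each thread
        // handles loopX = 4 elements, so with blockSize = 128 a tensor of sizeX = 1,000,000
        // elements needs gridSize = 999999 / (4 * 128) + 1 = 1954 blocks.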
         | 
| 171 | 
            +
                    void* args[] = {&p};
         | 
| 172 | 
            +
                    OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
         | 
| 173 | 
            +
                }
         | 
| 174 | 
            +
            };
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            REGISTER_OP("FusedBiasAct")
         | 
| 177 | 
            +
                .Input      ("x: T")
         | 
| 178 | 
            +
                .Input      ("b: T")
         | 
| 179 | 
            +
                .Input      ("ref: T")
         | 
| 180 | 
            +
                .Output     ("y: T")
         | 
| 181 | 
            +
                .Attr       ("T: {float, half}")
         | 
| 182 | 
            +
                .Attr       ("grad: int = 0")
         | 
| 183 | 
            +
                .Attr       ("axis: int = 1")
         | 
| 184 | 
            +
                .Attr       ("act: int = 0")
         | 
| 185 | 
            +
                .Attr       ("alpha: float = 0.0")
         | 
| 186 | 
            +
                .Attr       ("gain: float = 1.0");
         | 
| 187 | 
            +
            REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
         | 
| 188 | 
            +
            REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
         | 
| 189 | 
            +
             | 
| 190 | 
            +
            //------------------------------------------------------------------------
         | 
    	
        dnnlib/tflib/ops/fused_bias_act.py
    ADDED
    
    | @@ -0,0 +1,198 @@ | |
| 1 | 
            +
            # Copyright (c) SenseTime Research. All rights reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # This work is made available under the Nvidia Source Code License-NC.
         | 
| 6 | 
            +
            # To view a copy of this license, visit
         | 
| 7 | 
            +
            # https://nvlabs.github.io/stylegan2/license.html
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """Custom TensorFlow ops for efficient bias and activation."""
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import tensorflow as tf
         | 
| 14 | 
            +
            from .. import custom_ops
         | 
| 15 | 
            +
            from ...util import EasyDict
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            def _get_plugin():
         | 
| 18 | 
            +
                return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            #----------------------------------------------------------------------------
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            activation_funcs = {
         | 
| 23 | 
            +
                'linear':   EasyDict(func=lambda x, **_:        x,                          def_alpha=None, def_gain=1.0,           cuda_idx=1, ref='y', zero_2nd_grad=True),
         | 
| 24 | 
            +
                'relu':     EasyDict(func=lambda x, **_:        tf.nn.relu(x),              def_alpha=None, def_gain=np.sqrt(2),    cuda_idx=2, ref='y', zero_2nd_grad=True),
         | 
| 25 | 
            +
                'lrelu':    EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', zero_2nd_grad=True),
         | 
| 26 | 
            +
                'tanh':     EasyDict(func=lambda x, **_:        tf.nn.tanh(x),              def_alpha=None, def_gain=1.0,           cuda_idx=4, ref='y', zero_2nd_grad=False),
         | 
| 27 | 
            +
                'sigmoid':  EasyDict(func=lambda x, **_:        tf.nn.sigmoid(x),           def_alpha=None, def_gain=1.0,           cuda_idx=5, ref='y', zero_2nd_grad=False),
         | 
| 28 | 
            +
                'elu':      EasyDict(func=lambda x, **_:        tf.nn.elu(x),               def_alpha=None, def_gain=1.0,           cuda_idx=6, ref='y', zero_2nd_grad=False),
         | 
| 29 | 
            +
                'selu':     EasyDict(func=lambda x, **_:        tf.nn.selu(x),              def_alpha=None, def_gain=1.0,           cuda_idx=7, ref='y', zero_2nd_grad=False),
         | 
| 30 | 
            +
                'softplus': EasyDict(func=lambda x, **_:        tf.nn.softplus(x),          def_alpha=None, def_gain=1.0,           cuda_idx=8, ref='y', zero_2nd_grad=False),
         | 
| 31 | 
            +
                'swish':    EasyDict(func=lambda x, **_:        tf.nn.sigmoid(x) * x,       def_alpha=None, def_gain=np.sqrt(2),    cuda_idx=9, ref='x', zero_2nd_grad=False),
         | 
| 32 | 
            +
            }
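# Illustrative sketch (not part of the original file): each entry pairs a reference TF
# implementation with the metadata used by the wrappers below. Resolving the defaults for
# 'lrelu', for example, would look roughly like this:
#
#   spec  = activation_funcs['lrelu']
#   alpha = spec.def_alpha              # 0.2
#   gain  = spec.def_gain               # sqrt(2)
#   y     = spec.func(x, alpha=alpha) * gain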
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            #----------------------------------------------------------------------------
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
         | 
| 37 | 
            +
                r"""Fused bias and activation function.
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
         | 
| 40 | 
            +
                and scales the result by `gain`. Each of the steps is optional. In most cases,
         | 
| 41 | 
            +
                the fused op is considerably more efficient than performing the same calculation
         | 
| 42 | 
            +
                using standard TensorFlow ops. It supports first and second order gradients,
         | 
| 43 | 
            +
                but not third order gradients.
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                Args:
         | 
| 46 | 
            +
                    x:      Input activation tensor. Can have any shape, but if `b` is defined, the
         | 
| 47 | 
            +
                            dimension corresponding to `axis`, as well as the rank, must be known.
         | 
| 48 | 
            +
                    b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
         | 
| 49 | 
            +
                            as `x`. The shape must be known, and it must match the dimension of `x`
         | 
| 50 | 
            +
                            corresponding to `axis`.
         | 
| 51 | 
            +
                    axis:   The dimension in `x` corresponding to the elements of `b`.
         | 
| 52 | 
            +
                            The value of `axis` is ignored if `b` is not specified.
         | 
| 53 | 
            +
                    act:    Name of the activation function to evaluate, or `"linear"` to disable.
         | 
| 54 | 
            +
                            Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
         | 
| 55 | 
            +
                            See `activation_funcs` for a full list. `None` is not allowed.
         | 
| 56 | 
            +
                    alpha:  Shape parameter for the activation function, or `None` to use the default.
         | 
| 57 | 
            +
                    gain:   Scaling factor for the output tensor, or `None` to use default.
         | 
| 58 | 
            +
                            See `activation_funcs` for the default scaling of each activation function.
         | 
| 59 | 
            +
                            If unsure, consider specifying `1.0`.
         | 
| 60 | 
            +
                    impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                Returns:
         | 
| 63 | 
            +
                    Tensor of the same shape and datatype as `x`.
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                impl_dict = {
         | 
| 67 | 
            +
                    'ref':  _fused_bias_act_ref,
         | 
| 68 | 
            +
                    'cuda': _fused_bias_act_cuda,
         | 
| 69 | 
            +
                }
         | 
| 70 | 
            +
                return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            #----------------------------------------------------------------------------
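# Usage sketch (illustrative only, not part of the original file): adding a per-channel
# bias followed by leaky ReLU to an NCHW activation tensor. The tensor names and shapes
# here are assumptions for the example; impl='ref' avoids building the CUDA plugin.
#
#   import numpy as np
#   import tensorflow as tf
#   x = tf.constant(np.random.randn(8, 64, 32, 32).astype(np.float32))  # [N, C, H, W]
#   b = tf.constant(np.zeros(64, dtype=np.float32))                     # one bias per channel
#   y = fused_bias_act(x, b=b, axis=1, act='lrelu', impl='ref')         # gain defaults to sqrt(2)
#   # y has the same shape and dtype as x.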
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
         | 
| 75 | 
            +
                """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                # Validate arguments.
         | 
| 78 | 
            +
                x = tf.convert_to_tensor(x)
         | 
| 79 | 
            +
                b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
         | 
| 80 | 
            +
                act_spec = activation_funcs[act]
         | 
| 81 | 
            +
                assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
         | 
| 82 | 
            +
                assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
         | 
| 83 | 
            +
                if alpha is None:
         | 
| 84 | 
            +
                    alpha = act_spec.def_alpha
         | 
| 85 | 
            +
                if gain is None:
         | 
| 86 | 
            +
                    gain = act_spec.def_gain
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                # Add bias.
         | 
| 89 | 
            +
                if b.shape[0] != 0:
         | 
| 90 | 
            +
                    x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
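        # For example, with x of shape [N, C, H, W] and axis=1, b (length C) is reshaped to
        # [1, C, 1, 1] so the addition broadcasts the per-channel bias over N, H, and W.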
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                # Evaluate activation function.
         | 
| 93 | 
            +
                x = act_spec.func(x, alpha=alpha)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                # Scale by gain.
         | 
| 96 | 
            +
                if gain != 1:
         | 
| 97 | 
            +
                    x *= gain
         | 
| 98 | 
            +
                return x
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            #----------------------------------------------------------------------------
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
         | 
| 103 | 
            +
                """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                # Validate arguments.
         | 
| 106 | 
            +
                x = tf.convert_to_tensor(x)
         | 
| 107 | 
            +
                empty_tensor = tf.constant([], dtype=x.dtype)
         | 
| 108 | 
            +
                b = tf.convert_to_tensor(b) if b is not None else empty_tensor
         | 
| 109 | 
            +
                act_spec = activation_funcs[act]
         | 
| 110 | 
            +
                assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
         | 
| 111 | 
            +
                assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
         | 
| 112 | 
            +
                if alpha is None:
         | 
| 113 | 
            +
                    alpha = act_spec.def_alpha
         | 
| 114 | 
            +
                if gain is None:
         | 
| 115 | 
            +
                    gain = act_spec.def_gain
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                # Special cases.
         | 
| 118 | 
            +
    if act == 'linear' and b.shape[0] == 0 and gain == 1.0:  # b was converted to a tensor above, so test for an empty bias rather than None
         | 
| 119 | 
            +
                    return x
         | 
| 120 | 
            +
                if act_spec.cuda_idx is None:
         | 
| 121 | 
            +
                    return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                # CUDA kernel.
         | 
| 124 | 
            +
                cuda_kernel = _get_plugin().fused_bias_act
         | 
| 125 | 
            +
                cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                # Forward pass: y = func(x, b).
         | 
| 128 | 
            +
                def func_y(x, b):
         | 
| 129 | 
            +
                    y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
         | 
| 130 | 
            +
                    y.set_shape(x.shape)
         | 
| 131 | 
            +
                    return y
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # Backward pass: dx, db = grad(dy, x, y)
         | 
| 134 | 
            +
                def grad_dx(dy, x, y):
         | 
| 135 | 
            +
                    ref = {'x': x, 'y': y}[act_spec.ref]
         | 
| 136 | 
            +
                    dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
         | 
| 137 | 
            +
                    dx.set_shape(x.shape)
         | 
| 138 | 
            +
                    return dx
         | 
| 139 | 
            +
                def grad_db(dx):
         | 
| 140 | 
            +
                    if b.shape[0] == 0:
         | 
| 141 | 
            +
                        return empty_tensor
         | 
| 142 | 
            +
                    db = dx
         | 
| 143 | 
            +
                    if axis < x.shape.rank - 1:
         | 
| 144 | 
            +
                        db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
         | 
| 145 | 
            +
                    if axis > 0:
         | 
| 146 | 
            +
                        db = tf.reduce_sum(db, list(range(axis)))
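        # Together, the two reductions sum dx over every dimension except `axis`,
        # leaving db with the same length as b.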
         | 
| 147 | 
            +
                    db.set_shape(b.shape)
         | 
| 148 | 
            +
                    return db
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
         | 
| 151 | 
            +
                def grad2_d_dy(d_dx, d_db, x, y):
         | 
| 152 | 
            +
                    ref = {'x': x, 'y': y}[act_spec.ref]
         | 
| 153 | 
            +
                    d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
         | 
| 154 | 
            +
                    d_dy.set_shape(x.shape)
         | 
| 155 | 
            +
                    return d_dy
         | 
| 156 | 
            +
                def grad2_d_x(d_dx, d_db, x, y):
         | 
| 157 | 
            +
                    ref = {'x': x, 'y': y}[act_spec.ref]
         | 
| 158 | 
            +
                    d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
         | 
| 159 | 
            +
                    d_x.set_shape(x.shape)
         | 
| 160 | 
            +
                    return d_x
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                # Fast version for piecewise-linear activation funcs.
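    # For these activations the second derivative w.r.t. x is zero almost everywhere, so the
    # gradient of the backward pass only needs the d_dy term; no separate d_x is propagated.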
         | 
| 163 | 
            +
                @tf.custom_gradient
         | 
| 164 | 
            +
                def func_zero_2nd_grad(x, b):
         | 
| 165 | 
            +
                    y = func_y(x, b)
         | 
| 166 | 
            +
                    @tf.custom_gradient
         | 
| 167 | 
            +
                    def grad(dy):
         | 
| 168 | 
            +
                        dx = grad_dx(dy, x, y)
         | 
| 169 | 
            +
                        db = grad_db(dx)
         | 
| 170 | 
            +
                        def grad2(d_dx, d_db):
         | 
| 171 | 
            +
                            d_dy = grad2_d_dy(d_dx, d_db, x, y)
         | 
| 172 | 
            +
                            return d_dy
         | 
| 173 | 
            +
                        return (dx, db), grad2
         | 
| 174 | 
            +
                    return y, grad
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                # Slow version for general activation funcs.
         | 
| 177 | 
            +
                @tf.custom_gradient
         | 
| 178 | 
            +
                def func_nonzero_2nd_grad(x, b):
         | 
| 179 | 
            +
                    y = func_y(x, b)
         | 
| 180 | 
            +
                    def grad_wrap(dy):
         | 
| 181 | 
            +
                        @tf.custom_gradient
         | 
| 182 | 
            +
                        def grad_impl(dy, x):
         | 
| 183 | 
            +
                            dx = grad_dx(dy, x, y)
         | 
| 184 | 
            +
                            db = grad_db(dx)
         | 
| 185 | 
            +
                            def grad2(d_dx, d_db):
         | 
| 186 | 
            +
                                d_dy = grad2_d_dy(d_dx, d_db, x, y)
         | 
| 187 | 
            +
                                d_x = grad2_d_x(d_dx, d_db, x, y)
         | 
| 188 | 
            +
                                return d_dy, d_x
         | 
| 189 | 
            +
                            return (dx, db), grad2
         | 
| 190 | 
            +
                        return grad_impl(dy, x)
         | 
| 191 | 
            +
                    return y, grad_wrap
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                # Which version to use?
         | 
| 194 | 
            +
                if act_spec.zero_2nd_grad:
         | 
| 195 | 
            +
                    return func_zero_2nd_grad(x, b)
         | 
| 196 | 
            +
                return func_nonzero_2nd_grad(x, b)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
            #----------------------------------------------------------------------------
         | 
    	
        dnnlib/tflib/ops/upfirdn_2d.cu
    ADDED
    
    | @@ -0,0 +1,328 @@ | |
| 1 | 
            +
            // Copyright (c) SenseTime Research. All rights reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 4 | 
            +
            //
         | 
| 5 | 
            +
            // This work is made available under the Nvidia Source Code License-NC.
         | 
| 6 | 
            +
            // To view a copy of this license, visit
         | 
| 7 | 
            +
            // https://nvlabs.github.io/stylegan2/license.html
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            #define EIGEN_USE_GPU
         | 
| 10 | 
            +
            #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
         | 
| 11 | 
            +
            #include "tensorflow/core/framework/op.h"
         | 
| 12 | 
            +
            #include "tensorflow/core/framework/op_kernel.h"
         | 
| 13 | 
            +
            #include "tensorflow/core/framework/shape_inference.h"
         | 
| 14 | 
            +
            #include <stdio.h>
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            using namespace tensorflow;
         | 
| 17 | 
            +
            using namespace tensorflow::shape_inference;
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            //------------------------------------------------------------------------
         | 
| 20 | 
            +
            // Helpers.
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
         | 
| 25 | 
            +
            {
         | 
| 26 | 
            +
                int c = a / b;
         | 
| 27 | 
            +
                if (c * b > a)
         | 
| 28 | 
            +
                    c--;
         | 
| 29 | 
            +
                return c;
         | 
| 30 | 
            +
            }
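// Unlike C++'s built-in integer division, which truncates toward zero, this rounds toward
// negative infinity: e.g. floorDiv(-3, 2) == -2, whereas -3 / 2 == -1.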
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            //------------------------------------------------------------------------
         | 
| 33 | 
            +
            // CUDA kernel params.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            template <class T>
         | 
| 36 | 
            +
            struct UpFirDn2DKernelParams
         | 
| 37 | 
            +
            {
         | 
| 38 | 
            +
                const T*    x;          // [majorDim, inH, inW, minorDim]
         | 
| 39 | 
            +
                const T*    k;          // [kernelH, kernelW]
         | 
| 40 | 
            +
                T*          y;          // [majorDim, outH, outW, minorDim]
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                int         upx;
         | 
| 43 | 
            +
                int         upy;
         | 
| 44 | 
            +
                int         downx;
         | 
| 45 | 
            +
                int         downy;
         | 
| 46 | 
            +
                int         padx0;
         | 
| 47 | 
            +
                int         padx1;
         | 
| 48 | 
            +
                int         pady0;
         | 
| 49 | 
            +
                int         pady1;
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                int         majorDim;
         | 
| 52 | 
            +
                int         inH;
         | 
| 53 | 
            +
                int         inW;
         | 
| 54 | 
            +
                int         minorDim;
         | 
| 55 | 
            +
                int         kernelH;
         | 
| 56 | 
            +
                int         kernelW;
         | 
| 57 | 
            +
                int         outH;
         | 
| 58 | 
            +
                int         outW;
         | 
| 59 | 
            +
                int         loopMajor;
         | 
| 60 | 
            +
                int         loopX;
         | 
| 61 | 
            +
            };
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            //------------------------------------------------------------------------
         | 
| 64 | 
            +
            // General CUDA implementation for large filter kernels.
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            template <class T>
         | 
| 67 | 
            +
            static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
         | 
| 68 | 
            +
            {
         | 
| 69 | 
            +
                // Calculate thread index.
         | 
| 70 | 
            +
                int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
         | 
| 71 | 
            +
                int outY = minorIdx / p.minorDim;
         | 
| 72 | 
            +
                minorIdx -= outY * p.minorDim;
         | 
| 73 | 
            +
                int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
         | 
| 74 | 
            +
                int majorIdxBase = blockIdx.z * p.loopMajor;
         | 
| 75 | 
            +
                if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
         | 
| 76 | 
            +
                    return;
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                // Setup Y receptive field.
         | 
| 79 | 
            +
                int midY = outY * p.downy + p.upy - 1 - p.pady0;
         | 
| 80 | 
            +
                int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
         | 
| 81 | 
            +
                int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
         | 
| 82 | 
            +
                int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                // Loop over majorDim and outX.
         | 
| 85 | 
            +
                for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
         | 
| 86 | 
            +
                for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
         | 
| 87 | 
            +
                {
         | 
| 88 | 
            +
                    // Setup X receptive field.
         | 
| 89 | 
            +
                    int midX = outX * p.downx + p.upx - 1 - p.padx0;
         | 
| 90 | 
            +
                    int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
         | 
| 91 | 
            +
                    int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
         | 
| 92 | 
            +
                    int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    // Initialize pointers.
         | 
| 95 | 
            +
                    const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
         | 
| 96 | 
            +
                    const T* kp = &p.k[kernelY * p.kernelW + kernelX];
         | 
| 97 | 
            +
                    int xpx = p.minorDim;
         | 
| 98 | 
            +
                    int kpx = -p.upx;
         | 
| 99 | 
            +
                    int xpy = p.inW * p.minorDim;
         | 
| 100 | 
            +
                    int kpy = -p.upy * p.kernelW;
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    // Inner loop.
         | 
| 103 | 
            +
                    float v = 0.0f;
         | 
| 104 | 
            +
                    for (int y = 0; y < h; y++)
         | 
| 105 | 
            +
                    {
         | 
| 106 | 
            +
                        for (int x = 0; x < w; x++)
         | 
| 107 | 
            +
                        {
         | 
| 108 | 
            +
                            v += (float)(*xp) * (float)(*kp);
         | 
| 109 | 
            +
                            xp += xpx;
         | 
| 110 | 
            +
                            kp += kpx;
         | 
| 111 | 
            +
                        }
         | 
| 112 | 
            +
                        xp += xpy - w * xpx;
         | 
| 113 | 
            +
                        kp += kpy - w * kpx;
         | 
| 114 | 
            +
                    }
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    // Store result.
         | 
| 117 | 
            +
                    p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
         | 
| 118 | 
            +
                }
         | 
| 119 | 
            +
            }
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            //------------------------------------------------------------------------
         | 
| 122 | 
            +
            // Specialized CUDA implementation for small filter kernels.
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
         | 
| 125 | 
            +
            static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
         | 
| 126 | 
            +
            {
         | 
| 127 | 
            +
                //assert(kernelW % upx == 0);
         | 
| 128 | 
            +
                //assert(kernelH % upy == 0);
         | 
| 129 | 
            +
                const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
         | 
| 130 | 
            +
                const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
         | 
| 131 | 
            +
                __shared__ volatile float sk[kernelH][kernelW];
         | 
| 132 | 
            +
                __shared__ volatile float sx[tileInH][tileInW];
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                // Calculate tile index.
         | 
| 135 | 
            +
                int minorIdx = blockIdx.x;
         | 
| 136 | 
            +
                int tileOutY = minorIdx / p.minorDim;
         | 
| 137 | 
            +
                minorIdx -= tileOutY * p.minorDim;
         | 
| 138 | 
            +
                tileOutY *= tileOutH;
         | 
| 139 | 
            +
                int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
         | 
| 140 | 
            +
                int majorIdxBase = blockIdx.z * p.loopMajor;
         | 
| 141 | 
            +
                if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
         | 
| 142 | 
            +
                    return;
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                // Load filter kernel (flipped).
         | 
| 145 | 
            +
                for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
         | 
| 146 | 
            +
                {
         | 
| 147 | 
            +
                    int ky = tapIdx / kernelW;
         | 
| 148 | 
            +
                    int kx = tapIdx - ky * kernelW;
         | 
| 149 | 
            +
                    float v = 0.0f;
         | 
| 150 | 
            +
                    if (kx < p.kernelW & ky < p.kernelH)
         | 
| 151 | 
            +
                        v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
         | 
| 152 | 
            +
                    sk[ky][kx] = v;
         | 
| 153 | 
            +
                }
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                // Loop over majorDim and outX.
         | 
| 156 | 
            +
                for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
         | 
| 157 | 
            +
                for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
         | 
| 158 | 
            +
                {
         | 
| 159 | 
            +
                    // Load input pixels.
         | 
| 160 | 
            +
                    int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
         | 
| 161 | 
            +
                    int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
         | 
| 162 | 
            +
                    int tileInX = floorDiv(tileMidX, upx);
         | 
| 163 | 
            +
                    int tileInY = floorDiv(tileMidY, upy);
         | 
| 164 | 
            +
                    __syncthreads();
         | 
| 165 | 
            +
                    for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
         | 
| 166 | 
            +
                    {
         | 
| 167 | 
            +
                        int relInY = inIdx / tileInW;
         | 
| 168 | 
            +
                        int relInX = inIdx - relInY * tileInW;
         | 
| 169 | 
            +
                        int inX = relInX + tileInX;
         | 
| 170 | 
            +
                        int inY = relInY + tileInY;
         | 
| 171 | 
            +
                        float v = 0.0f;
         | 
| 172 | 
            +
                        if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
         | 
| 173 | 
            +
                            v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
         | 
| 174 | 
            +
                        sx[relInY][relInX] = v;
         | 
| 175 | 
            +
                    }
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    // Loop over output pixels.
         | 
| 178 | 
            +
                    __syncthreads();
         | 
| 179 | 
            +
                    for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
         | 
| 180 | 
            +
                    {
         | 
| 181 | 
            +
                        int relOutY = outIdx / tileOutW;
         | 
| 182 | 
            +
                        int relOutX = outIdx - relOutY * tileOutW;
         | 
| 183 | 
            +
                        int outX = relOutX + tileOutX;
         | 
| 184 | 
            +
                        int outY = relOutY + tileOutY;
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                        // Setup receptive field.
         | 
| 187 | 
            +
                        int midX = tileMidX + relOutX * downx;
         | 
| 188 | 
            +
                        int midY = tileMidY + relOutY * downy;
         | 
| 189 | 
            +
                        int inX = floorDiv(midX, upx);
         | 
| 190 | 
            +
                        int inY = floorDiv(midY, upy);
         | 
| 191 | 
            +
                        int relInX = inX - tileInX;
         | 
| 192 | 
            +
                        int relInY = inY - tileInY;
         | 
| 193 | 
            +
                        int kernelX = (inX + 1) * upx - midX - 1; // flipped
         | 
| 194 | 
            +
                        int kernelY = (inY + 1) * upy - midY - 1; // flipped
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                        // Inner loop.
         | 
| 197 | 
            +
                        float v = 0.0f;
         | 
| 198 | 
            +
                        #pragma unroll
         | 
| 199 | 
            +
                        for (int y = 0; y < kernelH / upy; y++)
         | 
| 200 | 
            +
                            #pragma unroll
         | 
| 201 | 
            +
                            for (int x = 0; x < kernelW / upx; x++)
         | 
| 202 | 
            +
                                v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                        // Store result.
         | 
| 205 | 
            +
                        if (outX < p.outW & outY < p.outH)
         | 
| 206 | 
            +
                            p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
         | 
| 207 | 
            +
                    }
         | 
| 208 | 
            +
                }
         | 
| 209 | 
            +
            }
         | 
| 210 | 
            +
             | 
| 211 | 
            +
            //------------------------------------------------------------------------
         | 
| 212 | 
            +
            // TensorFlow op.
         | 
| 213 | 
            +
             | 
| 214 | 
            +
            template <class T>
         | 
| 215 | 
            +
            struct UpFirDn2DOp : public OpKernel
         | 
| 216 | 
            +
            {
         | 
| 217 | 
            +
                UpFirDn2DKernelParams<T> m_attribs;
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
         | 
| 220 | 
            +
                {
         | 
| 221 | 
            +
                    memset(&m_attribs, 0, sizeof(m_attribs));
         | 
| 222 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
         | 
| 223 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
         | 
| 224 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
         | 
| 225 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
         | 
| 226 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
         | 
| 227 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
         | 
| 228 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
         | 
| 229 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
         | 
| 230 | 
            +
                    OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
         | 
| 231 | 
            +
                    OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
         | 
| 232 | 
            +
                }
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                void Compute(OpKernelContext* ctx)
         | 
| 235 | 
            +
                {
         | 
| 236 | 
            +
                    UpFirDn2DKernelParams<T> p = m_attribs;
         | 
| 237 | 
            +
                    cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
         | 
| 240 | 
            +
                    const Tensor& k = ctx->input(1); // [kernelH, kernelW]
         | 
| 241 | 
            +
                    p.x = x.flat<T>().data();
         | 
| 242 | 
            +
                    p.k = k.flat<T>().data();
         | 
| 243 | 
            +
                    OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
         | 
| 244 | 
            +
                    OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
         | 
| 245 | 
            +
                    OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
         | 
| 246 | 
            +
                    OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    p.majorDim  = (int)x.dim_size(0);
         | 
| 249 | 
            +
                    p.inH       = (int)x.dim_size(1);
         | 
| 250 | 
            +
                    p.inW       = (int)x.dim_size(2);
         | 
| 251 | 
            +
                    p.minorDim  = (int)x.dim_size(3);
         | 
| 252 | 
            +
                    p.kernelH   = (int)k.dim_size(0);
         | 
| 253 | 
            +
                    p.kernelW   = (int)k.dim_size(1);
         | 
| 254 | 
            +
                    OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
         | 
| 257 | 
            +
                    p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
         | 
| 258 | 
            +
                    OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
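        // Worked example with illustrative values: inW = 8, upx = 1, padx0 = padx1 = 1,
        // kernelW = 4, downx = 2  =>  outW = (8 + 1 + 1 - 4 + 2) / 2 = 4.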
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
         | 
| 261 | 
            +
                    TensorShape ys;
         | 
| 262 | 
            +
                    ys.AddDim(p.majorDim);
         | 
| 263 | 
            +
                    ys.AddDim(p.outH);
         | 
| 264 | 
            +
                    ys.AddDim(p.outW);
         | 
| 265 | 
            +
                    ys.AddDim(p.minorDim);
         | 
| 266 | 
            +
                    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
         | 
| 267 | 
            +
                    p.y = y->flat<T>().data();
         | 
| 268 | 
            +
                    OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    // Choose CUDA kernel to use.
         | 
| 271 | 
            +
                    void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
         | 
| 272 | 
            +
                    int tileOutW = -1;
         | 
| 273 | 
            +
                    int tileOutH = -1;
         | 
| 274 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 275 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 276 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 277 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 278 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 279 | 
            +
                    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 280 | 
            +
                    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 281 | 
            +
                    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 282 | 
            +
                    if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
         | 
| 283 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>;  tileOutW = 32; tileOutH = 8;  }
         | 
| 284 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>;  tileOutW = 32; tileOutH = 8;  }
         | 
| 285 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>;  tileOutW = 32; tileOutH = 8;  }
         | 
| 286 | 
            +
                    if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>;  tileOutW = 32; tileOutH = 8;  }
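        // Note: these checks run in order and each match overwrites the previous choice, so the
        // last (most specialized, smallest-kernel) variant that fits is the one launched.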
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    // Choose launch params.
         | 
| 289 | 
            +
                    dim3 blockSize;
         | 
| 290 | 
            +
                    dim3 gridSize;
         | 
| 291 | 
            +
                    if (tileOutW > 0 && tileOutH > 0) // small
         | 
| 292 | 
            +
                    {
         | 
| 293 | 
            +
                        p.loopMajor = (p.majorDim - 1) / 16384 + 1;
         | 
| 294 | 
            +
                        p.loopX = 1;
         | 
| 295 | 
            +
                        blockSize = dim3(32 * 8, 1, 1);
         | 
| 296 | 
            +
                        gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
         | 
| 297 | 
            +
                    }
         | 
| 298 | 
            +
                    else // large
         | 
| 299 | 
            +
                    {
         | 
| 300 | 
            +
                        p.loopMajor = (p.majorDim - 1) / 16384 + 1;
         | 
| 301 | 
            +
                        p.loopX = 4;
         | 
| 302 | 
            +
                        blockSize = dim3(4, 32, 1);
         | 
| 303 | 
            +
                        gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
         | 
| 304 | 
            +
                    }
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    // Launch CUDA kernel.
         | 
| 307 | 
            +
                    void* args[] = {&p};
         | 
| 308 | 
            +
                    OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
         | 
| 309 | 
            +
                }
         | 
| 310 | 
            +
            };
         | 
| 311 | 
            +
             | 
| 312 | 
            +
            REGISTER_OP("UpFirDn2D")
         | 
| 313 | 
            +
                .Input      ("x: T")
         | 
| 314 | 
            +
                .Input      ("k: T")
         | 
| 315 | 
            +
                .Output     ("y: T")
         | 
| 316 | 
            +
                .Attr       ("T: {float, half}")
         | 
| 317 | 
            +
                .Attr       ("upx: int = 1")
         | 
| 318 | 
            +
                .Attr       ("upy: int = 1")
         | 
| 319 | 
            +
                .Attr       ("downx: int = 1")
         | 
| 320 | 
            +
                .Attr       ("downy: int = 1")
         | 
| 321 | 
            +
                .Attr       ("padx0: int = 0")
         | 
| 322 | 
            +
                .Attr       ("padx1: int = 0")
         | 
| 323 | 
            +
                .Attr       ("pady0: int = 0")
         | 
| 324 | 
            +
                .Attr       ("pady1: int = 0");
         | 
| 325 | 
            +
            REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
         | 
| 326 | 
            +
            REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
         | 
| 327 | 
            +
             | 
| 328 | 
            +
            //------------------------------------------------------------------------
         | 
    	
        dnnlib/tflib/ops/upfirdn_2d.py
    ADDED
    
    | @@ -0,0 +1,366 @@ | |
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Custom TensorFlow ops for efficient resampling of 2D images."""

import os
import numpy as np
import tensorflow as tf
from .. import custom_ops

def _get_plugin():
    return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')

#----------------------------------------------------------------------------

def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
    r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.

    Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
    and performs the following operations for each image, batched across
    `majorDim` and `minorDim`:

    1. Pad the image with zeros by the specified number of pixels on each side
       (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
       corresponds to cropping the image.

    2. Upsample the image by inserting the zeros after each pixel (`upx`, `upy`).

    3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
       image so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by throwing away pixels (`downx`, `downy`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
        k:      2D FIR filter of the shape `[firH, firW]`.
        upx:    Integer upsampling factor along the X-axis (default: 1).
        upy:    Integer upsampling factor along the Y-axis (default: 1).
        downx:  Integer downsampling factor along the X-axis (default: 1).
        downy:  Integer downsampling factor along the Y-axis (default: 1).
        padx0:  Number of pixels to pad on the left side (default: 0).
        padx1:  Number of pixels to pad on the right side (default: 0).
        pady0:  Number of pixels to pad on the top side (default: 0).
        pady1:  Number of pixels to pad on the bottom side (default: 0).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
    """

    impl_dict = {
        'ref':  _upfirdn_2d_ref,
        'cuda': _upfirdn_2d_cuda,
    }
    return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)

#----------------------------------------------------------------------------
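To make the parameter interplay concrete, here is a small illustrative call, not part of the file, using the slow reference path so no compiled CUDA plugin is needed. It assumes TF 1.x with statically known input shapes; the output size follows the same formula the CUDA path below uses, outH = (inH*upy + pady0 + pady1 - kernelH)//downy + 1, and likewise for width.

# Illustrative sketch (not part of the file): 2x upsampling of an 8x8 single-channel
# image with a normalized 3x3 box filter and 1-pixel padding on every side.
import numpy as np
import tensorflow as tf
from dnnlib.tflib.ops.upfirdn_2d import upfirdn_2d

x = tf.zeros([1, 8, 8, 1])                   # [majorDim, inH, inW, minorDim]
k = np.ones([3, 3], dtype=np.float32) / 9.0  # 3x3 box FIR filter
y = upfirdn_2d(x, k, upx=2, upy=2, padx0=1, padx1=1, pady0=1, pady1=1, impl='ref')
# outH = (8*2 + 1 + 1 - 3)//1 + 1 = 16, so y has static shape [1, 16, 16, 1].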

def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    assert x.shape.rank == 4
    inH = x.shape[1].value
    inW = x.shape[2].value
    minorDim = _shape(x, 3)
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    # Upsample (insert zeros).
    x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
    x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

    # Pad (crop if negative).
    x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
    x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]

    # Convolve with filter.
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
    w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
    x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
    x = tf.transpose(x, [0, 2, 3, 1])

    # Downsample (throw away pixels).
    return x[:, ::downy, ::downx, :]

#----------------------------------------------------------------------------

def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    majorDim, inH, inW, minorDim = x.shape.as_list()
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
    outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
    assert outW >= 1 and outH >= 1

    kc = tf.constant(k, dtype=x.dtype)
    gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
    gpadx0 = kernelW - padx0 - 1
    gpady0 = kernelH - pady0 - 1
    gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
    gpady1 = inH * upy - outH * downy + pady0 - upy + 1

    @tf.custom_gradient
    def func(x):
        y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
        y.set_shape([majorDim, outH, outW, minorDim])
        @tf.custom_gradient
        def grad(dy):
            dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
            dx.set_shape([majorDim, inH, inW, minorDim])
            return dx, func
        return y, grad
    return func(x)

#----------------------------------------------------------------------------

def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
    r"""Filter a batch of 2D images with the given FIR filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and filters each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """

    k = _setup_kernel(k) * gain
    p = k.shape[0] - 1
    return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)

#----------------------------------------------------------------------------

def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and upsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the upsampling factor.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to nearest-neighbor
                      upsampling.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = k.shape[0] - factor
    return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)

#----------------------------------------------------------------------------

def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and downsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the downsampling factor.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to average pooling.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)

#----------------------------------------------------------------------------
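The two wrappers above are the ones StyleGAN2-style networks typically call directly. A brief illustrative sketch, again not part of the file and assuming TF 1.x, shows the resulting shapes with the usual [1, 3, 3, 1] separable filter; the reference path is used so no CUDA plugin is needed, though its NCHW convolution generally needs a GPU to actually execute.

# Illustrative sketch (not part of the file): shape behaviour of the wrappers.
import tensorflow as tf
from dnnlib.tflib.ops.upfirdn_2d import upsample_2d, downsample_2d

x = tf.zeros([4, 3, 64, 64])                       # [N, C, H, W]
hi = upsample_2d(x, k=[1, 3, 3, 1], impl='ref')    # -> [4, 3, 128, 128]
lo = downsample_2d(x, k=[1, 3, 3, 1], impl='ref')  # -> [4, 3, 32, 32]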

def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to nearest-neighbor
                      upsampling.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    w = tf.convert_to_tensor(w)
    assert w.shape.rank == 4
    convH = w.shape[0].value
    convW = w.shape[1].value
    inC = _shape(w, 2)
    outC = _shape(w, 3)
    assert convW == convH

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = (k.shape[0] - factor) - (convW - 1)

    # Determine data dimensions.
    if data_format == 'NCHW':
        stride = [1, 1, factor, factor]
        output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
        num_groups = _shape(x, 1) // inC
    else:
        stride = [1, factor, factor, 1]
        output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
        num_groups = _shape(x, 3) // inC

    # Transpose weights.
    w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
    w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
    w = tf.reshape(w, [convH, convW, -1, num_groups * inC])

    # Execute.
    x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
    return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)

#----------------------------------------------------------------------------

def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to average pooling.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    w = tf.convert_to_tensor(w)
    convH, convW, _inC, _outC = w.shape.as_list()
    assert convW == convH
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = (k.shape[0] - factor) + (convW - 1)
    if data_format == 'NCHW':
        s = [1, 1, factor, factor]
    else:
        s = [1, factor, factor, 1]
    x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
    return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)

#----------------------------------------------------------------------------
# Internal helper funcs.

def _shape(tf_expr, dim_idx):
    if tf_expr.shape.rank is not None:
        dim = tf_expr.shape[dim_idx].value
        if dim is not None:
            return dim
    return tf.shape(tf_expr)[dim_idx]

def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k

def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
    assert data_format in ['NCHW', 'NHWC']
    assert x.shape.rank == 4
    y = x
    if data_format == 'NCHW':
        y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
    y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
    if data_format == 'NCHW':
        y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
    return y

#----------------------------------------------------------------------------
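For the fused variants defined above, the extra ingredient is the convolution weight of shape [filterH, filterW, inChannels, outChannels]. A short illustrative sketch, not part of the file and assuming TF 1.x, of how the shapes work out with a 2x factor:

# Illustrative sketch (not part of the file): fused upsample/downsample convolutions.
import tensorflow as tf
from dnnlib.tflib.ops.upfirdn_2d import upsample_conv_2d, conv_downsample_2d

x = tf.zeros([4, 16, 32, 32])   # [N, C, H, W]
w = tf.zeros([3, 3, 16, 8])     # 3x3 convolution mapping 16 -> 8 channels
y_up   = upsample_conv_2d(x, w, k=[1, 3, 3, 1], impl='ref')    # -> [4, 8, 64, 64]
y_down = conv_downsample_2d(x, w, k=[1, 3, 3, 1], impl='ref')  # -> [4, 8, 16, 16]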
    	
dnnlib/tflib/optimizer.py
ADDED
@@ -0,0 +1,338 @@
# Copyright (c) SenseTime Research. All rights reserved.

# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Helper wrapper for a Tensorflow optimizer."""

import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import List, Union

from . import autosummary
from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

try:
    # TensorFlow 1.13
    from tensorflow.python.ops import nccl_ops
except:
    # Older TensorFlow versions
    import tensorflow.contrib.nccl as nccl_ops

class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.
    """

    def __init__(self,
        name:                   str             = "Train",                  # Name string that will appear in TensorFlow graph.
        tf_optimizer:           str             = "tf.train.AdamOptimizer", # Underlying optimizer class.
        learning_rate:          TfExpressionEx  = 0.001,                    # Learning rate. Can vary over time.
        minibatch_multiplier:   TfExpressionEx  = None,                     # Treat N consecutive minibatches as one by accumulating gradients.
        share:                  "Optimizer"     = None,                     # Share internal state with a previously created optimizer?
        use_loss_scaling:       bool            = False,                    # Enable dynamic loss scaling for robust mixed-precision training?
        loss_scaling_init:      float           = 64.0,                     # Log2 of initial loss scaling factor.
        loss_scaling_inc:       float           = 0.0005,                   # Log2 of per-minibatch loss scaling increment when there is no overflow.
        loss_scaling_dec:       float           = 1.0,                      # Log2 of per-minibatch loss scaling decrement when there is an overflow.
        report_mem_usage:       bool            = False,                    # Report fine-grained memory usage statistics in TensorBoard?
        **kwargs):

        # Public fields.
        self.name                   = name
        self.learning_rate          = learning_rate
        self.minibatch_multiplier   = minibatch_multiplier
        self.id                     = self.name.replace("/", ".")
        self.scope                  = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class        = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs       = dict(kwargs)
        self.use_loss_scaling       = use_loss_scaling
        self.loss_scaling_init      = loss_scaling_init
        self.loss_scaling_inc       = loss_scaling_inc
        self.loss_scaling_dec       = loss_scaling_dec

        # Private fields.
        self._updates_applied       = False
        self._devices               = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers     = OrderedDict() # device_name => optimizer_class
        self._gradient_shapes       = None          # [shape, ...]
        self._report_mem_usage      = report_mem_usage

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device."""
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name             = device_name
        device.optimizer        = None          # Underlying optimizer:     optimizer_class
        device.loss_scaling_var = None          # Log2 of loss scaling:     tf.Variable
        device.grad_raw         = OrderedDict() # Raw gradients:            var => [grad, ...]
        device.grad_clean       = OrderedDict() # Clean gradients:          var => grad
        device.grad_acc_vars    = OrderedDict() # Accumulation sums:        var => tf.Variable
        device.grad_acc_count   = None          # Accumulation counter:     tf.Variable
        device.grad_acc         = OrderedDict() # Accumulated gradients:    var => grad

        # Setup TensorFlow objects.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                pass

        # Compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)
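For orientation between the two methods, here is a minimal usage sketch of the wrapper, not part of the file: register_gradients() is called once per GPU against that GPU's loss and that GPU's copy of the variables, and apply_updates() then builds a single training op that averages across devices. The names build_loss, trainables, and num_gpus are placeholders for the caller's own objects.

# Illustrative sketch (not part of the file): typical per-GPU wiring of the wrapper.
import tensorflow as tf
import dnnlib.tflib as tflib

opt = tflib.Optimizer(name='TrainG', learning_rate=0.002)
for gpu in range(num_gpus):                            # num_gpus: placeholder
    with tf.device('/gpu:%d' % gpu):
        loss = build_loss(gpu)                         # placeholder: per-GPU loss tensor
        opt.register_gradients(loss, trainables[gpu])  # that GPU's copy of the trainables
train_op = opt.apply_updates()                         # averages across GPUs, applies once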
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]              # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)       # Multiple gradients => sum.

                    # Scale as needed.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                    if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors.
                        all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
             | 
| 243 | 
            +
                            # Adjust loss scaling.
         | 
| 244 | 
            +
                            if self.use_loss_scaling:
         | 
| 245 | 
            +
                                ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
         | 
| 246 | 
            +
                                ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
         | 
| 247 | 
            +
                                ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
         | 
| 248 | 
            +
                                all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                            # Last device => report statistics.
         | 
| 251 | 
            +
                            if device_idx == len(self._devices) - 1:
         | 
| 252 | 
            +
                                all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
         | 
| 253 | 
            +
                                all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
         | 
| 254 | 
            +
                                if self.use_loss_scaling:
         | 
| 255 | 
            +
                                    all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                    # Initialize variables.
         | 
| 258 | 
            +
                    self.reset_optimizer_state()
         | 
| 259 | 
            +
                    if self.use_loss_scaling:
         | 
| 260 | 
            +
                        tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
         | 
| 261 | 
            +
                    if self.minibatch_multiplier is not None:
         | 
| 262 | 
            +
                        tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    # Group everything into a single op.
         | 
| 265 | 
            +
                    with tfutil.absolute_name_scope(self.scope):
         | 
| 266 | 
            +
                        return tf.group(*all_ops, name="TrainingOp")
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                def reset_optimizer_state(self) -> None:
         | 
| 269 | 
            +
                    """Reset internal state of the underlying optimizer."""
         | 
| 270 | 
            +
                    tfutil.assert_tf_initialized()
         | 
| 271 | 
            +
                    tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
         | 
| 274 | 
            +
                    """Get or create variable representing log2 of the current dynamic loss scaling factor."""
         | 
| 275 | 
            +
                    return self._get_device(device).loss_scaling_var
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
         | 
| 278 | 
            +
                    """Apply dynamic loss scaling for the given expression."""
         | 
| 279 | 
            +
                    assert tfutil.is_tf_expression(value)
         | 
| 280 | 
            +
                    if not self.use_loss_scaling:
         | 
| 281 | 
            +
                        return value
         | 
| 282 | 
            +
                    return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
         | 
| 285 | 
            +
                    """Undo the effect of dynamic loss scaling for the given expression."""
         | 
| 286 | 
            +
                    assert tfutil.is_tf_expression(value)
         | 
| 287 | 
            +
                    if not self.use_loss_scaling:
         | 
| 288 | 
            +
                        return value
         | 
| 289 | 
            +
                    return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
         | 
| 290 | 
            +
             | 
| 291 | 
            +
             | 
| 292 | 
            +
            class SimpleAdam:
         | 
| 293 | 
            +
                """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
         | 
| 296 | 
            +
                    self.name = name
         | 
| 297 | 
            +
                    self.learning_rate = learning_rate
         | 
| 298 | 
            +
                    self.beta1 = beta1
         | 
| 299 | 
            +
                    self.beta2 = beta2
         | 
| 300 | 
            +
                    self.epsilon = epsilon
         | 
| 301 | 
            +
                    self.all_state_vars = []
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                def variables(self):
         | 
| 304 | 
            +
                    return self.all_state_vars
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
         | 
| 307 | 
            +
                    assert gate_gradients == tf.train.Optimizer.GATE_NONE
         | 
| 308 | 
            +
                    return list(zip(tf.gradients(loss, var_list), var_list))
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                def apply_gradients(self, grads_and_vars):
         | 
| 311 | 
            +
                    with tf.name_scope(self.name):
         | 
| 312 | 
            +
                        state_vars = []
         | 
| 313 | 
            +
                        update_ops = []
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                        # Adjust learning rate to deal with startup bias.
         | 
| 316 | 
            +
                        with tf.control_dependencies(None):
         | 
| 317 | 
            +
                            b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
         | 
| 318 | 
            +
                            b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
         | 
| 319 | 
            +
                            state_vars += [b1pow_var, b2pow_var]
         | 
| 320 | 
            +
                        b1pow_new = b1pow_var * self.beta1
         | 
| 321 | 
            +
                        b2pow_new = b2pow_var * self.beta2
         | 
| 322 | 
            +
                        update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
         | 
| 323 | 
            +
                        lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                        # Construct ops to update each variable.
         | 
| 326 | 
            +
                        for grad, var in grads_and_vars:
         | 
| 327 | 
            +
                            with tf.control_dependencies(None):
         | 
| 328 | 
            +
                                m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
         | 
| 329 | 
            +
                                v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
         | 
| 330 | 
            +
                                state_vars += [m_var, v_var]
         | 
| 331 | 
            +
                            m_new = self.beta1 * m_var + (1 - self.beta1) * grad
         | 
| 332 | 
            +
                            v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
         | 
| 333 | 
            +
                            var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
         | 
| 334 | 
            +
                            update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                        # Group everything together.
         | 
| 337 | 
            +
                        self.all_state_vars += state_vars
         | 
| 338 | 
            +
                        return tf.group(*update_ops)
         | 
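For orientation only (this is not part of the committed diff): a minimal sketch of how the Optimizer wrapper above is typically driven, assuming TensorFlow 1.x and the standard dnnlib.tflib.Optimizer entry points (register_gradients, apply_updates). The toy loss, `latents`, and `toy_weights` below are illustrative placeholders, not names from this repository.

```python
# Hedged usage sketch -- assumes TF 1.x and the dnnlib.tflib.Optimizer API;
# the toy graph is a stand-in for a real generator/discriminator loss.
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib

tflib.init_tf()

# Placeholder loss: mean squared projection of random latents.
latents = tf.placeholder(tf.float32, shape=[None, 512], name="latents")
toy_weights = tf.Variable(tf.random_normal([512, 1]), name="toy_weights")
loss = tf.reduce_mean(tf.square(tf.matmul(latents, toy_weights)))

# Register gradients, then build the combined training op shown in the diff above.
opt = tflib.Optimizer(name="TrainToy", learning_rate=0.002)
opt.register_gradients(loss, [toy_weights])
train_op = opt.apply_updates()
tflib.init_uninitialized_vars()  # initialize toy_weights and optimizer state

# One training step.
tflib.run(train_op, feed_dict={latents: np.random.randn(8, 512).astype(np.float32)})
```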
    	
        dnnlib/tflib/tfutil.py
    ADDED
    
    | @@ -0,0 +1,254 @@ | |
| 1 | 
            +
            # Copyright (c) SenseTime Research. All rights reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # This work is made available under the Nvidia Source Code License-NC.
         | 
| 6 | 
            +
            # To view a copy of this license, visit
         | 
| 7 | 
            +
            # https://nvlabs.github.io/stylegan2/license.html
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """Miscellaneous helper utils for Tensorflow."""
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import tensorflow as tf
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # Silence deprecation warnings from TensorFlow 1.13 onwards
         | 
| 16 | 
            +
            import logging
         | 
| 17 | 
            +
            logging.getLogger('tensorflow').setLevel(logging.ERROR)
         | 
| 18 | 
            +
            import tensorflow.contrib   # requires TensorFlow 1.x!
         | 
| 19 | 
            +
            tf.contrib = tensorflow.contrib
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from typing import Any, Iterable, List, Union
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
         | 
| 24 | 
            +
            """A type that represents a valid Tensorflow expression."""
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
         | 
| 27 | 
            +
            """A type that can be converted to a valid Tensorflow expression."""
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def run(*args, **kwargs) -> Any:
         | 
| 31 | 
            +
                """Run the specified ops in the default session."""
         | 
| 32 | 
            +
                assert_tf_initialized()
         | 
| 33 | 
            +
                return tf.get_default_session().run(*args, **kwargs)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def is_tf_expression(x: Any) -> bool:
         | 
| 37 | 
            +
                """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
         | 
| 38 | 
            +
                return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
         | 
| 42 | 
            +
                """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
         | 
| 43 | 
            +
                return [dim.value for dim in shape]
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def flatten(x: TfExpressionEx) -> TfExpression:
         | 
| 47 | 
            +
                """Shortcut function for flattening a tensor."""
         | 
| 48 | 
            +
                with tf.name_scope("Flatten"):
         | 
| 49 | 
            +
                    return tf.reshape(x, [-1])
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def log2(x: TfExpressionEx) -> TfExpression:
         | 
| 53 | 
            +
                """Logarithm in base 2."""
         | 
| 54 | 
            +
                with tf.name_scope("Log2"):
         | 
| 55 | 
            +
                    return tf.log(x) * np.float32(1.0 / np.log(2.0))
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            def exp2(x: TfExpressionEx) -> TfExpression:
         | 
| 59 | 
            +
                """Exponent in base 2."""
         | 
| 60 | 
            +
                with tf.name_scope("Exp2"):
         | 
| 61 | 
            +
                    return tf.exp(x * np.float32(np.log(2.0)))
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
         | 
| 65 | 
            +
                """Linear interpolation."""
         | 
| 66 | 
            +
                with tf.name_scope("Lerp"):
         | 
| 67 | 
            +
                    return a + (b - a) * t
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
         | 
| 71 | 
            +
                """Linear interpolation with clip."""
         | 
| 72 | 
            +
                with tf.name_scope("LerpClip"):
         | 
| 73 | 
            +
                    return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def absolute_name_scope(scope: str) -> tf.name_scope:
         | 
| 77 | 
            +
                """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
         | 
| 78 | 
            +
                return tf.name_scope(scope + "/")
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
         | 
| 82 | 
            +
                """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
         | 
| 83 | 
            +
                return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def _sanitize_tf_config(config_dict: dict = None) -> dict:
         | 
| 87 | 
            +
                # Defaults.
         | 
| 88 | 
            +
                cfg = dict()
         | 
| 89 | 
            +
                cfg["rnd.np_random_seed"]               = None      # Random seed for NumPy. None = keep as is.
         | 
| 90 | 
            +
                cfg["rnd.tf_random_seed"]               = "auto"    # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
         | 
| 91 | 
            +
                cfg["env.TF_CPP_MIN_LOG_LEVEL"]         = "1"       # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
         | 
| 92 | 
            +
                cfg["graph_options.place_pruned_graph"] = True      # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
         | 
| 93 | 
            +
                cfg["gpu_options.allow_growth"]         = True      # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                # Remove defaults for environment variables that are already set.
         | 
| 96 | 
            +
                for key in list(cfg):
         | 
| 97 | 
            +
                    fields = key.split(".")
         | 
| 98 | 
            +
                    if fields[0] == "env":
         | 
| 99 | 
            +
                        assert len(fields) == 2
         | 
| 100 | 
            +
                        if fields[1] in os.environ:
         | 
| 101 | 
            +
                            del cfg[key]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                # User overrides.
         | 
| 104 | 
            +
                if config_dict is not None:
         | 
| 105 | 
            +
                    cfg.update(config_dict)
         | 
| 106 | 
            +
                return cfg
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
            def init_tf(config_dict: dict = None) -> None:
         | 
| 110 | 
            +
                """Initialize TensorFlow session using good default settings."""
         | 
| 111 | 
            +
                # Skip if already initialized.
         | 
| 112 | 
            +
                if tf.get_default_session() is not None:
         | 
| 113 | 
            +
                    return
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                # Setup config dict and random seeds.
         | 
| 116 | 
            +
                cfg = _sanitize_tf_config(config_dict)
         | 
| 117 | 
            +
                np_random_seed = cfg["rnd.np_random_seed"]
         | 
| 118 | 
            +
                if np_random_seed is not None:
         | 
| 119 | 
            +
                    np.random.seed(np_random_seed)
         | 
| 120 | 
            +
                tf_random_seed = cfg["rnd.tf_random_seed"]
         | 
| 121 | 
            +
                if tf_random_seed == "auto":
         | 
| 122 | 
            +
                    tf_random_seed = np.random.randint(1 << 31)
         | 
| 123 | 
            +
                if tf_random_seed is not None:
         | 
| 124 | 
            +
                    tf.set_random_seed(tf_random_seed)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                # Setup environment variables.
         | 
| 127 | 
            +
                for key, value in cfg.items():
         | 
| 128 | 
            +
                    fields = key.split(".")
         | 
| 129 | 
            +
                    if fields[0] == "env":
         | 
| 130 | 
            +
                        assert len(fields) == 2
         | 
| 131 | 
            +
                        os.environ[fields[1]] = str(value)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # Create default TensorFlow session.
         | 
| 134 | 
            +
                create_session(cfg, force_as_default=True)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
| 137 | 
            +
            def assert_tf_initialized():
         | 
| 138 | 
            +
                """Check that TensorFlow session has been initialized."""
         | 
| 139 | 
            +
                if tf.get_default_session() is None:
         | 
| 140 | 
            +
                    raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
         | 
| 144 | 
            +
                """Create tf.Session based on config dict."""
         | 
| 145 | 
            +
                # Setup TensorFlow config proto.
         | 
| 146 | 
            +
                cfg = _sanitize_tf_config(config_dict)
         | 
| 147 | 
            +
                config_proto = tf.ConfigProto()
         | 
| 148 | 
            +
                for key, value in cfg.items():
         | 
| 149 | 
            +
                    fields = key.split(".")
         | 
| 150 | 
            +
                    if fields[0] not in ["rnd", "env"]:
         | 
| 151 | 
            +
                        obj = config_proto
         | 
| 152 | 
            +
                        for field in fields[:-1]:
         | 
| 153 | 
            +
                            obj = getattr(obj, field)
         | 
| 154 | 
            +
                        setattr(obj, fields[-1], value)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                # Create session.
         | 
| 157 | 
            +
                session = tf.Session(config=config_proto)
         | 
| 158 | 
            +
                if force_as_default:
         | 
| 159 | 
            +
                    # pylint: disable=protected-access
         | 
| 160 | 
            +
                    session._default_session = session.as_default()
         | 
| 161 | 
            +
                    session._default_session.enforce_nesting = False
         | 
| 162 | 
            +
                    session._default_session.__enter__()
         | 
| 163 | 
            +
                return session
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
         | 
| 167 | 
            +
                """Initialize all tf.Variables that have not already been initialized.
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                Equivalent to the following, but more efficient and does not bloat the tf graph:
         | 
| 170 | 
            +
                tf.variables_initializer(tf.report_uninitialized_variables()).run()
         | 
| 171 | 
            +
                """
         | 
| 172 | 
            +
                assert_tf_initialized()
         | 
| 173 | 
            +
                if target_vars is None:
         | 
| 174 | 
            +
                    target_vars = tf.global_variables()
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                test_vars = []
         | 
| 177 | 
            +
                test_ops = []
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                with tf.control_dependencies(None):  # ignore surrounding control_dependencies
         | 
| 180 | 
            +
                    for var in target_vars:
         | 
| 181 | 
            +
                        assert is_tf_expression(var)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                        try:
         | 
| 184 | 
            +
                            tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
         | 
| 185 | 
            +
                        except KeyError:
         | 
| 186 | 
            +
                            # Op does not exist => variable may be uninitialized.
         | 
| 187 | 
            +
                            test_vars.append(var)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                            with absolute_name_scope(var.name.split(":")[0]):
         | 
| 190 | 
            +
                                test_ops.append(tf.is_variable_initialized(var))
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
         | 
| 193 | 
            +
                run([var.initializer for var in init_vars])
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            def set_vars(var_to_value_dict: dict) -> None:
         | 
| 197 | 
            +
                """Set the values of given tf.Variables.
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                Equivalent to the following, but more efficient and does not bloat the tf graph:
         | 
| 200 | 
            +
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
         | 
| 201 | 
            +
                """
         | 
| 202 | 
            +
                assert_tf_initialized()
         | 
| 203 | 
            +
                ops = []
         | 
| 204 | 
            +
                feed_dict = {}
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                for var, value in var_to_value_dict.items():
         | 
| 207 | 
            +
                    assert is_tf_expression(var)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    try:
         | 
| 210 | 
            +
                        setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0"))  # look for existing op
         | 
| 211 | 
            +
                    except KeyError:
         | 
| 212 | 
            +
                        with absolute_name_scope(var.name.split(":")[0]):
         | 
| 213 | 
            +
                            with tf.control_dependencies(None):  # ignore surrounding control_dependencies
         | 
| 214 | 
            +
                                setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter")  # create new setter
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    ops.append(setter)
         | 
| 217 | 
            +
                    feed_dict[setter.op.inputs[1]] = value
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                run(ops, feed_dict)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
             | 
| 222 | 
            +
            def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
         | 
| 223 | 
            +
                """Create tf.Variable with large initial value without bloating the tf graph."""
         | 
| 224 | 
            +
                assert_tf_initialized()
         | 
| 225 | 
            +
                assert isinstance(initial_value, np.ndarray)
         | 
| 226 | 
            +
                zeros = tf.zeros(initial_value.shape, initial_value.dtype)
         | 
| 227 | 
            +
                var = tf.Variable(zeros, *args, **kwargs)
         | 
| 228 | 
            +
                set_vars({var: initial_value})
         | 
| 229 | 
            +
                return var
         | 
| 230 | 
            +
             | 
| 231 | 
            +
             | 
| 232 | 
            +
            def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
         | 
| 233 | 
            +
                """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
         | 
| 234 | 
            +
                Can be used as an input transformation for Network.run().
         | 
| 235 | 
            +
                """
         | 
| 236 | 
            +
                images = tf.cast(images, tf.float32)
         | 
| 237 | 
            +
                if nhwc_to_nchw:
         | 
| 238 | 
            +
                    images = tf.transpose(images, [0, 3, 1, 2])
         | 
| 239 | 
            +
                return images * ((drange[1] - drange[0]) / 255) + drange[0]
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
         | 
| 243 | 
            +
                """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
         | 
| 244 | 
            +
                Can be used as an output transformation for Network.run().
         | 
| 245 | 
            +
                """
         | 
| 246 | 
            +
                images = tf.cast(images, tf.float32)
         | 
| 247 | 
            +
                if shrink > 1:
         | 
| 248 | 
            +
                    ksize = [1, 1, shrink, shrink]
         | 
| 249 | 
            +
                    images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
         | 
| 250 | 
            +
                if nchw_to_nhwc:
         | 
| 251 | 
            +
                    images = tf.transpose(images, [0, 2, 3, 1])
         | 
| 252 | 
            +
                scale = 255 / (drange[1] - drange[0])
         | 
| 253 | 
            +
                images = images * scale + (0.5 - drange[0] * scale)
         | 
| 254 | 
            +
                return tf.saturate_cast(images, tf.uint8)
         | 
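Again for orientation only (not part of the committed diff): a short sketch exercising a few of the tfutil helpers above -- init_tf(), create_var_with_large_initial_value(), set_vars(), and convert_images_to_uint8() -- under TensorFlow 1.x. All shapes, names, and values are illustrative.

```python
# Hedged sanity sketch for dnnlib.tflib.tfutil helpers (TF 1.x assumed).
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib

tflib.init_tf({"gpu_options.allow_growth": True})  # skipped if a default session already exists

# Create a variable from a large NumPy array without embedding it as a graph constant.
init_value = np.random.randn(3, 64, 64).astype(np.float32)
var = tflib.create_var_with_large_initial_value(init_value, name="demo_var")

# Overwrite it later; set_vars() reuses a per-variable "setter" op instead of adding new assigns.
tflib.set_vars({var: np.zeros_like(init_value)})

# Convert a fake NCHW float batch in [-1, 1] to NHWC uint8, e.g. for saving to PNG.
images = tf.random_uniform([4, 3, 64, 64], minval=-1.0, maxval=1.0)
images_uint8 = tflib.convert_images_to_uint8(images, drange=[-1, 1], nchw_to_nhwc=True)
print(tflib.run(images_uint8).shape)  # (4, 64, 64, 3)
```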
    	
        dnnlib/util.py
    ADDED
    
    | @@ -0,0 +1,479 @@ | |
| 1 | 
            +
            # Copyright (c) SenseTime Research. All rights reserved.
         | 
| 2 | 
            +
            # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         | 
| 5 | 
            +
            # and proprietary rights in and to this software, related documentation
         | 
| 6 | 
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         | 
| 7 | 
            +
            # distribution of this software and related documentation without an express
         | 
| 8 | 
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            """Miscellaneous utility classes and functions."""
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import ctypes
         | 
| 13 | 
            +
            import fnmatch
         | 
| 14 | 
            +
            import importlib
         | 
| 15 | 
            +
            import inspect
         | 
| 16 | 
            +
            import numpy as np
         | 
| 17 | 
            +
            import os
         | 
| 18 | 
            +
            import shutil
         | 
| 19 | 
            +
            import sys
         | 
| 20 | 
            +
            import types
         | 
| 21 | 
            +
            import io
         | 
| 22 | 
            +
            import pickle
         | 
| 23 | 
            +
            import re
         | 
| 24 | 
            +
            import requests
         | 
| 25 | 
            +
            import html
         | 
| 26 | 
            +
            import hashlib
         | 
| 27 | 
            +
            import glob
         | 
| 28 | 
            +
            import tempfile
         | 
| 29 | 
            +
            import urllib
         | 
| 30 | 
            +
            import urllib.request
         | 
| 31 | 
            +
            import uuid
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            from distutils.util import strtobool
         | 
| 34 | 
            +
            from typing import Any, List, Tuple, Union
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            # Util classes
         | 
| 38 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            class EasyDict(dict):
         | 
| 42 | 
            +
                """Convenience class that behaves like a dict but allows access with the attribute syntax."""
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def __getattr__(self, name: str) -> Any:
         | 
| 45 | 
            +
                    try:
         | 
| 46 | 
            +
                        return self[name]
         | 
| 47 | 
            +
                    except KeyError:
         | 
| 48 | 
            +
                        raise AttributeError(name)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                def __setattr__(self, name: str, value: Any) -> None:
         | 
| 51 | 
            +
                    self[name] = value
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def __delattr__(self, name: str) -> None:
         | 
| 54 | 
            +
                    del self[name]
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            class Logger(object):
         | 
| 58 | 
            +
                """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
         | 
| 61 | 
            +
                    self.file = None
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    if file_name is not None:
         | 
| 64 | 
            +
                        self.file = open(file_name, file_mode)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self.should_flush = should_flush
         | 
| 67 | 
            +
                    self.stdout = sys.stdout
         | 
| 68 | 
            +
                    self.stderr = sys.stderr
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    sys.stdout = self
         | 
| 71 | 
            +
                    sys.stderr = self
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def __enter__(self) -> "Logger":
         | 
| 74 | 
            +
                    return self
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         | 
| 77 | 
            +
                    self.close()
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def write(self, text: Union[str, bytes]) -> None:
         | 
| 80 | 
            +
                    """Write text to stdout (and a file) and optionally flush."""
         | 
| 81 | 
            +
                    if isinstance(text, bytes):
         | 
| 82 | 
            +
                        text = text.decode()
         | 
| 83 | 
            +
                    if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
         | 
| 84 | 
            +
                        return
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    if self.file is not None:
         | 
| 87 | 
            +
                        self.file.write(text)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    self.stdout.write(text)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    if self.should_flush:
         | 
| 92 | 
            +
                        self.flush()
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def flush(self) -> None:
         | 
| 95 | 
            +
                    """Flush written text to both stdout and a file, if open."""
         | 
| 96 | 
            +
                    if self.file is not None:
         | 
| 97 | 
            +
                        self.file.flush()
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.stdout.flush()
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def close(self) -> None:
         | 
| 102 | 
            +
                    """Flush, close possible files, and remove stdout/stderr mirroring."""
         | 
| 103 | 
            +
                    self.flush()
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    # if using multiple loggers, prevent closing in wrong order
         | 
| 106 | 
            +
                    if sys.stdout is self:
         | 
| 107 | 
            +
                        sys.stdout = self.stdout
         | 
| 108 | 
            +
                    if sys.stderr is self:
         | 
| 109 | 
            +
                        sys.stderr = self.stderr
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    if self.file is not None:
         | 
| 112 | 
            +
                        self.file.close()
         | 
| 113 | 
            +
                        self.file = None
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            # Cache directories
         | 
| 117 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            _dnnlib_cache_dir = None
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            def set_cache_dir(path: str) -> None:
         | 
| 122 | 
            +
                global _dnnlib_cache_dir
         | 
| 123 | 
            +
                _dnnlib_cache_dir = path
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            def make_cache_dir_path(*paths: str) -> str:
         | 
| 126 | 
            +
                if _dnnlib_cache_dir is not None:
         | 
| 127 | 
            +
                    return os.path.join(_dnnlib_cache_dir, *paths)
         | 
| 128 | 
            +
                if 'DNNLIB_CACHE_DIR' in os.environ:
         | 
| 129 | 
            +
                    return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
         | 
| 130 | 
            +
                if 'HOME' in os.environ:
         | 
| 131 | 
            +
                    return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
         | 
| 132 | 
            +
                if 'USERPROFILE' in os.environ:
         | 
| 133 | 
            +
                    return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
         | 
| 134 | 
            +
                return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
            # Small util functions
         | 
| 137 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            def format_time(seconds: Union[int, float]) -> str:
         | 
| 141 | 
            +
                """Convert the seconds to human readable string with days, hours, minutes and seconds."""
         | 
| 142 | 
            +
                s = int(np.rint(seconds))
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                if s < 60:
         | 
| 145 | 
            +
                    return "{0}s".format(s)
         | 
| 146 | 
            +
                elif s < 60 * 60:
         | 
| 147 | 
            +
                    return "{0}m {1:02}s".format(s // 60, s % 60)
         | 
| 148 | 
            +
                elif s < 24 * 60 * 60:
         | 
| 149 | 
            +
                    return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
         | 
| 150 | 
            +
                else:
         | 
| 151 | 
            +
                    return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def ask_yes_no(question: str) -> bool:
         | 
| 155 | 
            +
                """Ask the user the question until the user inputs a valid answer."""
         | 
| 156 | 
            +
                while True:
         | 
| 157 | 
            +
                    try:
         | 
| 158 | 
            +
                        print("{0} [y/n]".format(question))
         | 
| 159 | 
            +
                        return strtobool(input().lower())
         | 
| 160 | 
            +
                    except ValueError:
         | 
| 161 | 
            +
                        pass
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            def tuple_product(t: Tuple) -> Any:
         | 
| 165 | 
            +
                """Calculate the product of the tuple elements."""
         | 
| 166 | 
            +
                result = 1
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                for v in t:
         | 
| 169 | 
            +
                    result *= v
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                return result
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            _str_to_ctype = {
         | 
| 175 | 
            +
                "uint8": ctypes.c_ubyte,
         | 
| 176 | 
            +
                "uint16": ctypes.c_uint16,
         | 
| 177 | 
            +
                "uint32": ctypes.c_uint32,
         | 
| 178 | 
            +
                "uint64": ctypes.c_uint64,
         | 
| 179 | 
            +
                "int8": ctypes.c_byte,
         | 
| 180 | 
            +
                "int16": ctypes.c_int16,
         | 
| 181 | 
            +
                "int32": ctypes.c_int32,
         | 
| 182 | 
            +
                "int64": ctypes.c_int64,
         | 
| 183 | 
            +
                "float32": ctypes.c_float,
         | 
| 184 | 
            +
                "float64": ctypes.c_double
         | 
| 185 | 
            +
            }
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
         | 
| 189 | 
            +
                """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
         | 
| 190 | 
            +
                type_str = None
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                if isinstance(type_obj, str):
         | 
| 193 | 
            +
                    type_str = type_obj
         | 
| 194 | 
            +
                elif hasattr(type_obj, "__name__"):
         | 
| 195 | 
            +
                    type_str = type_obj.__name__
         | 
| 196 | 
            +
                elif hasattr(type_obj, "name"):
         | 
| 197 | 
            +
                    type_str = type_obj.name
         | 
| 198 | 
            +
                else:
         | 
| 199 | 
            +
                    raise RuntimeError("Cannot infer type name from input")
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                assert type_str in _str_to_ctype.keys()
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                my_dtype = np.dtype(type_str)
         | 
| 204 | 
            +
                my_ctype = _str_to_ctype[type_str]
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                return my_dtype, my_ctype
         | 
| 209 | 
            +
             | 
| 210 | 
            +
             | 
| 211 | 
            +
            def is_pickleable(obj: Any) -> bool:
         | 
| 212 | 
            +
                try:
         | 
| 213 | 
            +
                    with io.BytesIO() as stream:
         | 
| 214 | 
            +
                        pickle.dump(obj, stream)
         | 
| 215 | 
            +
                    return True
         | 
| 216 | 
            +
                except:
         | 
| 217 | 
            +
                    return False
         | 
| 218 | 
            +
             | 
| 219 | 
            +
             | 
| 220 | 
            +
            # Functionality to import modules/objects by name, and call functions by name
         | 
| 221 | 
            +
            # ------------------------------------------------------------------------------------------
         | 
| 222 | 
            +
             | 
| 223 | 
            +
            def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
         | 
| 224 | 
            +
                """Searches for the underlying module behind the name to some python object.
         | 
| 225 | 
            +
                Returns the module and the object name (original name with module part removed)."""
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                # allow convenience shorthands, substitute them by full names
         | 
| 228 | 
            +
                obj_name = re.sub("^np.", "numpy.", obj_name)
         | 
| 229 | 
            +
                obj_name = re.sub("^tf.", "tensorflow.", obj_name)
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                # list alternatives for (module_name, local_obj_name)
         | 
| 232 | 
            +
                parts = obj_name.split(".")
         | 
| 233 | 
            +
                name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                # try each alternative in turn
         | 
| 236 | 
            +
                for module_name, local_obj_name in name_pairs:
         | 
| 237 | 
            +
                    try:
         | 
| 238 | 
            +
                        module = importlib.import_module(module_name) # may raise ImportError
         | 
| 239 | 
            +
                        get_obj_from_module(module, local_obj_name) # may raise AttributeError
         | 
| 240 | 
            +
                        return module, local_obj_name
         | 
| 241 | 
            +
                    except:
         | 
| 242 | 
            +
                        pass
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                # maybe some of the modules themselves contain errors?
         | 
| 245 | 
            +
                for module_name, _local_obj_name in name_pairs:
         | 
| 246 | 
            +
                    try:
         | 
| 247 | 
            +
                        importlib.import_module(module_name) # may raise ImportError
         | 
| 248 | 
            +
                    except ImportError:
         | 
| 249 | 
            +
                        if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
         | 
| 250 | 
            +
                            raise
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                # maybe the requested attribute is missing?
         | 
| 253 | 
            +
                for module_name, local_obj_name in name_pairs:
         | 
| 254 | 
            +
                    try:
         | 
| 255 | 
            +
                        module = importlib.import_module(module_name) # may raise ImportError
         | 
| 256 | 
            +
                        get_obj_from_module(module, local_obj_name) # may raise AttributeError
         | 
| 257 | 
            +
                    except ImportError:
         | 
| 258 | 
            +
                        pass
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                # we are out of luck, but we have no idea why
         | 
| 261 | 
            +
                raise ImportError(obj_name)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
             | 
| 264 | 
            +
            def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
         | 
| 265 | 
            +
                """Traverses the object name and returns the last (rightmost) python object."""
         | 
| 266 | 
            +
                if obj_name == '':
         | 
| 267 | 
            +
                    return module
         | 
| 268 | 
            +
                obj = module
         | 
| 269 | 
            +
                for part in obj_name.split("."):
         | 
| 270 | 
            +
                    obj = getattr(obj, part)
         | 
| 271 | 
            +
                return obj
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            def get_obj_by_name(name: str) -> Any:
         | 
| 275 | 
            +
                """Finds the python object with the given name."""
         | 
| 276 | 
            +
                module, obj_name = get_module_from_obj_name(name)
         | 
| 277 | 
            +
                return get_obj_from_module(module, obj_name)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
         | 
| 281 | 
            +
                """Finds the python object with the given name and calls it as a function."""
         | 
| 282 | 
            +
                assert func_name is not None
         | 
| 283 | 
            +
                # print('func_name: ', func_name) #'training.dataset.ImageFolderDataset'
         | 
| 284 | 
            +
                func_obj = get_obj_by_name(func_name) 
         | 
| 285 | 
            +
                assert callable(func_obj)
         | 
| 286 | 
            +
                return func_obj(*args, **kwargs)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
             | 
| 289 | 
            +
            def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
         | 
| 290 | 
            +
                """Finds the python class with the given name and constructs it with the given arguments."""
         | 
| 291 | 
            +
                return call_func_by_name(*args, func_name=class_name, **kwargs)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
             | 
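The helpers above resolve a dotted name to a live Python object: get_module_from_obj_name() tries progressively shorter module prefixes, get_obj_from_module()/get_obj_by_name() walk the remaining attributes, and call_func_by_name()/construct_class_by_name() invoke the result. A minimal sketch (illustrative only, assuming dnnlib.util and numpy are importable):

# Resolve a dotted name; the "np." shorthand is expanded to "numpy." internally.
from dnnlib import util

rnd_cls = util.get_obj_by_name("np.random.RandomState")
print(rnd_cls)  # the numpy.random.RandomState class object

# Construct an instance by name, forwarding positional and keyword arguments.
rnd = util.construct_class_by_name(123, class_name="numpy.random.RandomState")
print(rnd.randint(10))
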
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__

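These helpers map a module-level function to its fully-qualified name and back, so callables can be referenced by string (for example in config dictionaries). A small round-trip sketch (illustrative only, not part of the file above):

from dnnlib import util

name = util.get_top_level_function_name(util.is_pickleable)
print(name)                          # "dnnlib.util.is_pickleable"
fn = util.get_obj_by_name(name)
print(fn is util.is_pickleable)      # True: the name round-trips to the same object
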
# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        target_dir_name = os.path.dirname(file[1])

        # will create all intermediate-level directories
        if not os.path.exists(target_dir_name):
            os.makedirs(target_dir_name)

        shutil.copyfile(file[0], file[1])

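list_dir_recursively_with_ignore() and copy_files_and_create_dirs() are designed to be chained: the (absolute, relative) pairs from the first feed the (src, dst) pairs expected by the second, which makes it easy to mirror a source tree while skipping ignored patterns. A minimal sketch (illustrative only; src_dir and backup_dir are hypothetical paths):

import os
from dnnlib import util

src_dir = "training"      # hypothetical source directory
backup_dir = "_backup"    # hypothetical destination directory

pairs = util.list_dir_recursively_with_ignore(src_dir, ignores=["__pycache__", "*.pyc"], add_base_to_relative=True)
files = [(src, os.path.join(backup_dir, rel)) for src, rel in pairs]
util.copy_files_and_create_dirs(files)
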
# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or not "://" in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True

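is_url() accepts a string only if it parses with both a scheme and a dotted network location; file:// URLs pass only when allow_file_urls=True. A few illustrative checks (not part of the file above):

from dnnlib import util

print(util.is_url("https://nvlabs.github.io/stylegan2/license.html"))  # True
print(util.is_url("not-a-url"))                                        # False: no scheme or netloc
print(util.is_url("file:///tmp/data.pkl"))                             # False by default
print(util.is_url("file:///tmp/data.pkl", allow_file_urls=True))       # True
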
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs.  This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid.  Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
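
open_url() transparently handles local paths, file:// URLs, and remote downloads: remote files are fetched with retries, cached under an MD5-keyed filename in the cache directory, and returned either as a binary file object or, with return_filename=True, as the cached path. A minimal sketch (illustrative only; the URL is a hypothetical placeholder, not a real checkpoint):

import pickle
from dnnlib import util

url = "https://example.com/models/network-snapshot.pkl"   # hypothetical URL
with util.open_url(url, cache=True, verbose=True) as f:
    data = pickle.load(f)

# return_filename=True (valid only together with cache=True) yields the cached path instead.
path = util.open_url(url, cache=True, return_filename=True)
print(path)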