Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long | |
| """Hyperparameter values.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import json | |
| import numbers | |
| import re | |
| import six | |
# Regular expression for parsing a single clause of the input (clauses are
# delimited by commas). A legal clause looks like:
#   <variable name>[<index>]? = <rhs>
# where <rhs> is either a single token or a []-enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]".
#
# Named groups exposed to callers through `match.groupdict()`:
#   name:  the variable name (a letter followed by word characters or dots).
#   index: the digits inside an optional "[...]" suffix, or None.
#   val:   the scalar right-hand side, or None when the RHS is a list.
#   vals:  the text between "[" and "]" of a list RHS, or None for a scalar.
# The pattern also consumes the trailing comma (and following whitespace), so
# `match.end()` is the start of the next clause.
PARAM_RE = re.compile(
    r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
\s*=\s*
((?P<val>[^,\[]*) # single value: "a" or None
|
\[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
($|,\s*)""",
    re.VERBOSE,
)
| def _parse_fail(name, var_type, value, values): | |
| """Helper function for raising a value error for bad assignment.""" | |
| raise ValueError( | |
| "Could not parse hparam '%s' of type '%s' with value '%s' in %s" | |
| % (name, var_type.__name__, value, values) | |
| ) | |
| def _reuse_fail(name, values): | |
| """Helper function for raising a value error for reuse of name.""" | |
| raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values)) | |
| def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary): | |
| """Update results_dictionary with a scalar value. | |
| Used to update the results_dictionary to be returned by parse_values when | |
| encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) | |
| Mutates results_dictionary. | |
| Args: | |
| name: Name of variable in assignment ("s" or "arr"). | |
| parse_fn: Function for parsing the actual value. | |
| var_type: Type of named variable. | |
| m_dict: Dictionary constructed from regex parsing. | |
| m_dict['val']: RHS value (scalar) | |
| m_dict['index']: List index value (or None) | |
| values: Full expression being parsed | |
| results_dictionary: The dictionary being updated for return by the parsing | |
| function. | |
| Raises: | |
| ValueError: If the name has already been used. | |
| """ | |
| try: | |
| parsed_value = parse_fn(m_dict["val"]) | |
| except ValueError: | |
| _parse_fail(name, var_type, m_dict["val"], values) | |
| # If no index is provided | |
| if not m_dict["index"]: | |
| if name in results_dictionary: | |
| _reuse_fail(name, values) | |
| results_dictionary[name] = parsed_value | |
| else: | |
| if name in results_dictionary: | |
| # The name has already been used as a scalar, then it | |
| # will be in this dictionary and map to a non-dictionary. | |
| if not isinstance(results_dictionary.get(name), dict): | |
| _reuse_fail(name, values) | |
| else: | |
| results_dictionary[name] = {} | |
| index = int(m_dict["index"]) | |
| # Make sure the index position hasn't already been assigned a value. | |
| if index in results_dictionary[name]: | |
| _reuse_fail("{}[{}]".format(name, index), values) | |
| results_dictionary[name][index] = parsed_value | |
| def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary): | |
| """Update results_dictionary from a list of values. | |
| Used to update results_dictionary to be returned by parse_values when | |
| encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) | |
| Mutates results_dictionary. | |
| Args: | |
| name: Name of variable in assignment ("arr"). | |
| parse_fn: Function for parsing individual values. | |
| var_type: Type of named variable. | |
| m_dict: Dictionary constructed from regex parsing. | |
| m_dict['val']: RHS value (scalar) | |
| values: Full expression being parsed | |
| results_dictionary: The dictionary being updated for return by the parsing | |
| function. | |
| Raises: | |
| ValueError: If the name has an index or the values cannot be parsed. | |
| """ | |
| if m_dict["index"] is not None: | |
| raise ValueError("Assignment of a list to a list index.") | |
| elements = filter(None, re.split("[ ,]", m_dict["vals"])) | |
| # Make sure the name hasn't already been assigned a value | |
| if name in results_dictionary: | |
| raise _reuse_fail(name, values) | |
| try: | |
| results_dictionary[name] = [parse_fn(e) for e in elements] | |
| except ValueError: | |
| _parse_fail(name, var_type, m_dict["vals"], values) | |
| def _cast_to_type_if_compatible(name, param_type, value): | |
| """Cast hparam to the provided type, if compatible. | |
| Args: | |
| name: Name of the hparam to be cast. | |
| param_type: The type of the hparam. | |
| value: The value to be cast, if compatible. | |
| Returns: | |
| The result of casting `value` to `param_type`. | |
| Raises: | |
| ValueError: If the type of `value` is not compatible with param_type. | |
| * If `param_type` is a string type, but `value` is not. | |
| * If `param_type` is a boolean, but `value` is not, or vice versa. | |
| * If `param_type` is an integer type, but `value` is not. | |
| * If `param_type` is a float type, but `value` is not a numeric type. | |
| """ | |
| fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % ( | |
| name, | |
| param_type, | |
| value, | |
| ) | |
| # Some callers use None, for which we can't do any casting/checking. :( | |
| if issubclass(param_type, type(None)): | |
| return value | |
| # Avoid converting a non-string type to a string. | |
| if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance( | |
| value, (six.string_types, six.binary_type) | |
| ): | |
| raise ValueError(fail_msg) | |
| # Avoid converting a number or string type to a boolean or vice versa. | |
| if issubclass(param_type, bool) != isinstance(value, bool): | |
| raise ValueError(fail_msg) | |
| # Avoid converting float to an integer (the reverse is fine). | |
| if issubclass(param_type, numbers.Integral) and not isinstance( | |
| value, numbers.Integral | |
| ): | |
| raise ValueError(fail_msg) | |
| # Avoid converting a non-numeric type to a numeric type. | |
| if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number): | |
| raise ValueError(fail_msg) | |
| return param_type(value) | |
def _parse_bool_token(value):
    """Convert one textual token into a bool.

    Accepts "true"/"True" and "false"/"False"; any other token is read as an
    integer literal and converted with `bool()`. A token that is none of
    these raises ValueError, which the clause processors translate into the
    standard parse-failure message.
    """
    if value in ("true", "True"):
        return True
    if value in ("false", "False"):
        return False
    return bool(int(value))


def parse_values(values, type_map, ignore_unknown=False):
    """Parses hyperparameter values from a string into a python map.

    `values` is a string containing comma-separated `name=value` pairs.
    For each pair, the value of the hyperparameter named `name` is set to
    `value`.

    If a hyperparameter name appears multiple times in `values`, a ValueError
    is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').

    If a hyperparameter name appears in both an index assignment and a scalar
    or list assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0] = 1').

    The hyperparameter name may contain '.' symbols, which will result in an
    attribute name that is only accessible through the getattr and setattr
    functions. (And must first be explicitly added through add_hparam.)

    WARNING: Use of '.' in your variable names is allowed, but is not well
    supported and not recommended.

    The `value` in `name=value` must follow the syntax according to the
    type of the parameter:

    *  Scalar integer: A Python-parsable integer value. E.g.: 1,
       100, -12.
    *  Scalar float: A Python-parsable floating point value. E.g.: 1.0,
       -.54e89.
    *  Boolean: true/True or false/False; an integer literal is also
       accepted and converted with `bool()`.
    *  Scalar string: A non-empty sequence of characters, excluding comma,
       spaces, and square brackets. E.g.: foo, bar_1.
    *  List: A comma separated list of scalar values of the parameter type
       enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].

    When index assignment is used, the corresponding type_map key should be the
    list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
    "arr[1]").

    Args:
      values: String. Comma separated list of `name=value` pairs where
        'value' must follow the syntax described above. `None` is treated
        like the empty string (no assignments), so an unset command-line
        flag can be passed straight through.
      type_map: A dictionary mapping hyperparameter names to types. Note every
        parameter name in values must be a key in type_map (unless
        `ignore_unknown` is set). The values must conform to the types
        indicated, where a value V is said to conform to a type T if either
        V has type T, or V is a list of elements of type T. Hence, for a
        multidimensional parameter 'x' taking float values, 'x=[0.1,0.2]'
        will parse successfully if type_map['x'] = float.
      ignore_unknown: Bool. Whether values that are missing a type in type_map
        should be ignored. If set to True, a ValueError will not be raised for
        unknown hyperparameter type.

    Returns:
      A python map mapping each name to either:
      * A scalar value.
      * A list of scalar values.
      * A dictionary mapping index numbers to scalar values.
      (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}})

    Raises:
      ValueError: If there is a problem with input.
      * If `values` cannot be parsed.
      * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
      * If the same name is assigned two different values (e.g. 'a=1,a=2',
        'a[1]=1,a[1]=2', or 'a=1,a=[1]').
      * If a name is missing from `type_map` and `ignore_unknown` is False.
    """
    results_dictionary = {}
    # An omitted flag typically arrives as None; there is nothing to parse.
    if values is None:
        return results_dictionary

    pos = 0
    while pos < len(values):
        m = PARAM_RE.match(values, pos)
        if not m:
            raise ValueError("Malformed hyperparameter value: %s" % values[pos:])
        # The pattern consumes the trailing comma, so the end of the match is
        # the start of the next clause.
        pos = m.end()

        m_dict = m.groupdict()
        name = m_dict["name"]
        if name not in type_map:
            if ignore_unknown:
                continue
            raise ValueError("Unknown hyperparameter type for %s" % name)
        type_ = type_map[name]

        # bool("false") is True, so booleans need a dedicated parser; every
        # other type is parsed by calling the type itself.
        parse = _parse_bool_token if type_ is bool else type_

        if m_dict["val"] is not None:
            # Scalar right-hand side.
            _process_scalar_value(
                name, parse, type_, m_dict, values, results_dictionary
            )
        elif m_dict["vals"] is not None:
            # Bracketed list right-hand side.
            _process_list_value(name, parse, type_, m_dict, values, results_dictionary)
        else:  # Neither a list nor a value was captured.
            _parse_fail(name, type_, "", values)

    return results_dictionary
class HParams(object):
    """Class to hold a set of hyperparameters as name-value pairs.

    A `HParams` object holds hyperparameters used to build and train a model,
    such as the number of hidden units in a neural net layer or the learning rate
    to use when training.

    You first create a `HParams` object by specifying the names and values of the
    hyperparameters.

    To make them easily accessible the parameter names are added as direct
    attributes of the class. A typical usage is as follows:

    ```python
    # Create a HParams object specifying names and values of the model
    # hyperparameters:
    hparams = HParams(learning_rate=0.1, num_hidden_units=100)

    # The hyperparameters are available as attributes of the HParams object:
    hparams.learning_rate ==> 0.1
    hparams.num_hidden_units ==> 100
    ```

    Hyperparameters have a type, which is inferred from the type of the value
    passed at construction time. The currently supported types are: integer,
    float, boolean, string, and list of integer, float, boolean, or string.

    You can override hyperparameter values by calling the
    [`parse()`](#HParams.parse) method, passing a string of comma separated
    `name=value` pairs. This is intended to make it possible to override
    any hyperparameter values from a single command-line flag to which
    the user passes 'hyper-param=value' pairs. It avoids having to define
    one flag for each hyperparameter.

    The syntax expected for each value depends on the type of the parameter.
    See `parse()` for a description of the syntax.

    Example:

    ```python
    # Define a command line flag to pass name=value pairs.
    # For example using argparse:
    import argparse
    parser = argparse.ArgumentParser(description='Train my model.')
    parser.add_argument('--hparams', type=str,
                        help='Comma separated list of "name=value" pairs.')
    args = parser.parse_args()
    ...
    def my_program():
      # Create a HParams object specifying the names and values of the
      # model hyperparameters:
      hparams = HParams(learning_rate=0.1, num_hidden_units=100,
                        activations=['relu', 'tanh'])

      # Override hyperparameters values by parsing the command line
      hparams.parse(args.hparams)

      # If the user passed `--hparams=learning_rate=0.3` on the command line
      # then 'hparams' has the following attributes:
      hparams.learning_rate ==> 0.3
      hparams.num_hidden_units ==> 100
      hparams.activations ==> ['relu', 'tanh']

      # If the hyperparameters are in json format use parse_json:
      hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
    ```
    """

    _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.

    def __init__(self, model_structure=None, **kwargs):
        """Create an instance of `HParams` from keyword arguments.

        The keyword arguments specify name-values pairs for the hyperparameters.
        The parameter types are inferred from the type of the values passed.

        The parameter names are added as attributes of `HParams` object, so they
        can be accessed directly with the dot notation `hparams._name_`.

        Example:

        ```python
        # Define 3 hyperparameters: 'learning_rate' is a float parameter,
        # 'num_hidden_units' an integer parameter, and 'activation' a string
        # parameter.
        hparams = HParams(
            learning_rate=0.1, num_hidden_units=100, activation='relu')

        hparams.activation ==> 'relu'
        ```

        Note that a few names are reserved and cannot be used as hyperparameter
        names. If you use one of the reserved names the constructor raises a
        `ValueError`.

        Args:
          model_structure: Optional object stored as-is and returned by
            `get_model_structure()`; this class never inspects it.
          **kwargs: Key-value pairs where the key is the hyperparameter name and
            the value is the value for the parameter.

        Raises:
          ValueError: If one of the arguments is invalid (reserved name or
            empty list value).
        """
        # Register the hyperparameters and their type in _hparam_types.
        # This simplifies the implementation of parse().
        # _hparam_types maps the parameter name to a tuple (type, bool).
        # The type value is the type of the parameter for scalar hyperparameters,
        # or the type of the list elements for multidimensional hyperparameters.
        # The bool value is True if the value is a list, False otherwise.
        self._hparam_types = {}
        self._model_structure = model_structure
        for name, value in kwargs.items():
            self.add_hparam(name, value)

    def add_hparam(self, name, value):
        """Adds {name, value} pair to hyperparameters.

        Args:
          name: Name of the hyperparameter.
          value: Value of the hyperparameter. Can be one of the following types:
            int, float, string, int list, float list, or string list.

        Raises:
          ValueError: if one of the arguments is invalid.
        """
        # Keys in kwargs are unique, but 'name' could be the name of a
        # pre-existing attribute of this object (a method, or an earlier
        # hyperparameter). In that case we refuse to use it as a
        # hyperparameter name.
        if getattr(self, name, None) is not None:
            raise ValueError("Hyperparameter name is reserved: %s" % name)
        if isinstance(value, (list, tuple)):
            if not value:
                raise ValueError(
                    "Multi-valued hyperparameters cannot be empty: %s" % name
                )
            # The element type is taken from the first element.
            self._hparam_types[name] = (type(value[0]), True)
        else:
            self._hparam_types[name] = (type(value), False)
        setattr(self, name, value)

    def set_hparam(self, name, value):
        """Set the value of an existing hyperparameter.

        This function verifies that the type of the value matches the type of the
        existing hyperparameter.

        Args:
          name: Name of the hyperparameter.
          value: New value of the hyperparameter.

        Raises:
          KeyError: If the hyperparameter doesn't exist.
          ValueError: If there is a type mismatch.
        """
        param_type, is_list = self._hparam_types[name]
        if isinstance(value, list):
            if not is_list:
                raise ValueError(
                    "Must not pass a list for single-valued parameter: %s" % name
                )
            setattr(
                self,
                name,
                [_cast_to_type_if_compatible(name, param_type, v) for v in value],
            )
        else:
            if is_list:
                raise ValueError(
                    "Must pass a list for multi-valued parameter: %s." % name
                )
            setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))

    def del_hparam(self, name):
        """Removes the hyperparameter with key 'name'.

        Does nothing if it isn't present.

        Args:
          name: Name of the hyperparameter.
        """
        # Decide from the registry, not from hasattr(): methods and internal
        # attributes also satisfy hasattr() but are not hyperparameters, and
        # must be neither deleted nor looked up in the registry.
        if name not in self._hparam_types:
            return
        if hasattr(self, name):
            delattr(self, name)
        del self._hparam_types[name]

    def parse(self, values):
        """Override existing hyperparameter values, parsing new values from a string.

        See parse_values for more detail on the allowed format for values.

        Args:
          values: String. Comma separated list of `name=value` pairs where 'value'
            must follow the syntax described in `parse_values`.

        Returns:
          The `HParams` instance.

        Raises:
          ValueError: If `values` cannot be parsed or a hyperparameter in `values`
            doesn't exist.
        """
        # parse_values only needs the element type, not the is-list flag.
        type_map = {}
        for name, t in self._hparam_types.items():
            param_type, _ = t
            type_map[name] = param_type

        values_map = parse_values(values, type_map)
        return self.override_from_dict(values_map)

    def override_from_dict(self, values_dict):
        """Override existing hyperparameter values, parsing new values from a dictionary.

        Args:
          values_dict: Dictionary of name:value pairs.

        Returns:
          The `HParams` instance.

        Raises:
          KeyError: If a hyperparameter in `values_dict` doesn't exist.
          ValueError: If `values_dict` cannot be parsed.
        """
        for name, value in values_dict.items():
            self.set_hparam(name, value)
        return self

    def set_model_structure(self, model_structure):
        """Store `model_structure` on this instance, replacing any previous one."""
        self._model_structure = model_structure

    def get_model_structure(self):
        """Return the object last stored as the model structure (may be None)."""
        return self._model_structure

    def to_json(self, indent=None, separators=None, sort_keys=False):
        """Serializes the hyperparameters into JSON.

        Callable values (and callable list elements) are omitted from the
        output.

        Args:
          indent: If a non-negative integer, JSON array elements and object members
            will be pretty-printed with that indent level. An indent level of 0, or
            negative, will only insert newlines. `None` (the default) selects the
            most compact representation.
          separators: Optional `(item_separator, key_separator)` tuple. Default is
            `(', ', ': ')`.
          sort_keys: If `True`, the output dictionaries will be sorted by key.

        Returns:
          A JSON string.
        """

        def remove_callables(x):
            """Omit callable elements from input with arbitrary nesting."""
            if isinstance(x, dict):
                return {
                    k: remove_callables(v) for k, v in x.items() if not callable(v)
                }
            elif isinstance(x, list):
                return [remove_callables(i) for i in x if not callable(i)]
            return x

        return json.dumps(
            remove_callables(self.values()),
            indent=indent,
            separators=separators,
            sort_keys=sort_keys,
        )

    def parse_json(self, values_json):
        """Override existing hyperparameter values, parsing new values from a json object.

        Args:
          values_json: String containing a json object of name:value pairs.

        Returns:
          The `HParams` instance.

        Raises:
          KeyError: If a hyperparameter in `values_json` doesn't exist.
          ValueError: If `values_json` cannot be parsed.
        """
        values_map = json.loads(values_json)
        return self.override_from_dict(values_map)

    def values(self):
        """Return the hyperparameter values as a Python dictionary.

        Returns:
          A dictionary with hyperparameter names as keys.  The values are the
          hyperparameter values.
        """
        return {n: getattr(self, n) for n in self._hparam_types.keys()}

    def get(self, key, default=None):
        """Returns the value of `key` if it exists, else `default`.

        When `key` exists and a non-None `default` is given, the default is
        also checked for compatibility with the parameter's registered type.

        Raises:
          ValueError: If `default` is incompatible with the type of `key`.
        """
        if key in self._hparam_types:
            # Ensure that default is compatible with the parameter type.
            if default is not None:
                param_type, is_param_list = self._hparam_types[key]
                type_str = "list<%s>" % param_type if is_param_list else str(param_type)
                fail_msg = (
                    "Hparam '%s' of type '%s' is incompatible with "
                    "default=%s" % (key, type_str, default)
                )

                is_default_list = isinstance(default, list)
                if is_param_list != is_default_list:
                    raise ValueError(fail_msg)

                try:
                    if is_default_list:
                        for value in default:
                            _cast_to_type_if_compatible(key, param_type, value)
                    else:
                        _cast_to_type_if_compatible(key, param_type, default)
                except ValueError as e:
                    raise ValueError("%s. %s" % (fail_msg, e))

            return getattr(self, key)

        return default

    def __contains__(self, key):
        """Return True if `key` is a registered hyperparameter name."""
        return key in self._hparam_types

    def __str__(self):
        """Return the (name, value) pairs, sorted by name, as a string."""
        return str(sorted(self.values().items()))

    def __repr__(self):
        """Return `ClassName(<str(self)>)`."""
        return "%s(%s)" % (type(self).__name__, self.__str__())
| def _get_kind_name(param_type, is_list): | |
| """Returns the field name given parameter type and is_list. | |
| Args: | |
| param_type: Data type of the hparam. | |
| is_list: Whether this is a list. | |
| Returns: | |
| A string representation of the field name. | |
| Raises: | |
| ValueError: If parameter type is not recognized. | |
| """ | |
| if issubclass(param_type, bool): | |
| # This check must happen before issubclass(param_type, six.integer_types), | |
| # since Python considers bool to be a subclass of int. | |
| typename = "bool" | |
| elif issubclass(param_type, six.integer_types): | |
| # Setting 'int' and 'long' types to be 'int64' to ensure the type is | |
| # compatible with both Python2 and Python3. | |
| typename = "int64" | |
| elif issubclass(param_type, (six.string_types, six.binary_type)): | |
| # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is | |
| # compatible with both Python2 and Python3. | |
| typename = "bytes" | |
| elif issubclass(param_type, float): | |
| typename = "float" | |
| else: | |
| raise ValueError("Unsupported parameter type: %s" % str(param_type)) | |
| suffix = "list" if is_list else "value" | |
| return "_".join([typename, suffix]) | |