folder adjust
This commit is contained in:
parent
5b63663aeb
commit
320633a419
|
@ -1,731 +0,0 @@
|
||||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Hyperparameter values."""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import json
|
|
||||||
import numbers
|
|
||||||
import re
|
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
## from tensorflow.contrib.training.python.training import hparam_pb2
|
|
||||||
## from tensorflow.python.framework import ops
|
|
||||||
## from tensorflow.python.util import compat
|
|
||||||
## from tensorflow.python.util import deprecation
|
|
||||||
|
|
||||||
# Define the regular expression for parsing a single clause of the input
|
|
||||||
# (delimited by commas). A legal clause looks like:
|
|
||||||
# <variable name>[<index>]? = <rhs>
|
|
||||||
# where <rhs> is either a single token or [] enclosed list of tokens.
|
|
||||||
# For example: "var[1] = a" or "x = [1,2,3]"
|
|
||||||
PARAM_RE = re.compile(r"""
|
|
||||||
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
|
|
||||||
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
|
|
||||||
\s*=\s*
|
|
||||||
((?P<val>[^,\[]*) # single value: "a" or None
|
|
||||||
|
|
|
||||||
\[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
|
|
||||||
($|,\s*)""", re.VERBOSE)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_fail(name, var_type, value, values):
|
|
||||||
"""Helper function for raising a value error for bad assignment."""
|
|
||||||
raise ValueError(
|
|
||||||
'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
|
|
||||||
(name, var_type.__name__, value, values))
|
|
||||||
|
|
||||||
|
|
||||||
def _reuse_fail(name, values):
|
|
||||||
"""Helper function for raising a value error for reuse of name."""
|
|
||||||
raise ValueError('Multiple assignments to variable \'%s\' in %s' %
|
|
||||||
(name, values))
|
|
||||||
|
|
||||||
|
|
||||||
def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
|
|
||||||
results_dictionary):
|
|
||||||
"""Update results_dictionary with a scalar value.
|
|
||||||
|
|
||||||
Used to update the results_dictionary to be returned by parse_values when
|
|
||||||
encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
|
|
||||||
|
|
||||||
Mutates results_dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of variable in assignment ("s" or "arr").
|
|
||||||
parse_fn: Function for parsing the actual value.
|
|
||||||
var_type: Type of named variable.
|
|
||||||
m_dict: Dictionary constructed from regex parsing.
|
|
||||||
m_dict['val']: RHS value (scalar)
|
|
||||||
m_dict['index']: List index value (or None)
|
|
||||||
values: Full expression being parsed
|
|
||||||
results_dictionary: The dictionary being updated for return by the parsing
|
|
||||||
function.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the name has already been used.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
parsed_value = parse_fn(m_dict['val'])
|
|
||||||
except ValueError:
|
|
||||||
_parse_fail(name, var_type, m_dict['val'], values)
|
|
||||||
|
|
||||||
# If no index is provided
|
|
||||||
if not m_dict['index']:
|
|
||||||
if name in results_dictionary:
|
|
||||||
_reuse_fail(name, values)
|
|
||||||
results_dictionary[name] = parsed_value
|
|
||||||
else:
|
|
||||||
if name in results_dictionary:
|
|
||||||
# The name has already been used as a scalar, then it
|
|
||||||
# will be in this dictionary and map to a non-dictionary.
|
|
||||||
if not isinstance(results_dictionary.get(name), dict):
|
|
||||||
_reuse_fail(name, values)
|
|
||||||
else:
|
|
||||||
results_dictionary[name] = {}
|
|
||||||
|
|
||||||
index = int(m_dict['index'])
|
|
||||||
# Make sure the index position hasn't already been assigned a value.
|
|
||||||
if index in results_dictionary[name]:
|
|
||||||
_reuse_fail('{}[{}]'.format(name, index), values)
|
|
||||||
results_dictionary[name][index] = parsed_value
|
|
||||||
|
|
||||||
|
|
||||||
def _process_list_value(name, parse_fn, var_type, m_dict, values,
|
|
||||||
results_dictionary):
|
|
||||||
"""Update results_dictionary from a list of values.
|
|
||||||
|
|
||||||
Used to update results_dictionary to be returned by parse_values when
|
|
||||||
encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
|
|
||||||
|
|
||||||
Mutates results_dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of variable in assignment ("arr").
|
|
||||||
parse_fn: Function for parsing individual values.
|
|
||||||
var_type: Type of named variable.
|
|
||||||
m_dict: Dictionary constructed from regex parsing.
|
|
||||||
m_dict['val']: RHS value (scalar)
|
|
||||||
values: Full expression being parsed
|
|
||||||
results_dictionary: The dictionary being updated for return by the parsing
|
|
||||||
function.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the name has an index or the values cannot be parsed.
|
|
||||||
"""
|
|
||||||
if m_dict['index'] is not None:
|
|
||||||
raise ValueError('Assignment of a list to a list index.')
|
|
||||||
elements = filter(None, re.split('[ ,]', m_dict['vals']))
|
|
||||||
# Make sure the name hasn't already been assigned a value
|
|
||||||
if name in results_dictionary:
|
|
||||||
raise _reuse_fail(name, values)
|
|
||||||
try:
|
|
||||||
results_dictionary[name] = [parse_fn(e) for e in elements]
|
|
||||||
except ValueError:
|
|
||||||
_parse_fail(name, var_type, m_dict['vals'], values)
|
|
||||||
|
|
||||||
|
|
||||||
def _cast_to_type_if_compatible(name, param_type, value):
|
|
||||||
"""Cast hparam to the provided type, if compatible.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the hparam to be cast.
|
|
||||||
param_type: The type of the hparam.
|
|
||||||
value: The value to be cast, if compatible.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result of casting `value` to `param_type`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the type of `value` is not compatible with param_type.
|
|
||||||
* If `param_type` is a string type, but `value` is not.
|
|
||||||
* If `param_type` is a boolean, but `value` is not, or vice versa.
|
|
||||||
* If `param_type` is an integer type, but `value` is not.
|
|
||||||
* If `param_type` is a float type, but `value` is not a numeric type.
|
|
||||||
"""
|
|
||||||
fail_msg = ("Could not cast hparam '%s' of type '%s' from value %r" %
|
|
||||||
(name, param_type, value))
|
|
||||||
|
|
||||||
# Some callers use None, for which we can't do any casting/checking. :(
|
|
||||||
if issubclass(param_type, type(None)):
|
|
||||||
return value
|
|
||||||
|
|
||||||
# Avoid converting a non-string type to a string.
|
|
||||||
if (issubclass(param_type, (six.string_types, six.binary_type)) and
|
|
||||||
not isinstance(value, (six.string_types, six.binary_type))):
|
|
||||||
raise ValueError(fail_msg)
|
|
||||||
|
|
||||||
# Avoid converting a number or string type to a boolean or vice versa.
|
|
||||||
if issubclass(param_type, bool) != isinstance(value, bool):
|
|
||||||
raise ValueError(fail_msg)
|
|
||||||
|
|
||||||
# Avoid converting float to an integer (the reverse is fine).
|
|
||||||
if (issubclass(param_type, numbers.Integral) and
|
|
||||||
not isinstance(value, numbers.Integral)):
|
|
||||||
raise ValueError(fail_msg)
|
|
||||||
|
|
||||||
# Avoid converting a non-numeric type to a numeric type.
|
|
||||||
if (issubclass(param_type, numbers.Number) and
|
|
||||||
not isinstance(value, numbers.Number)):
|
|
||||||
raise ValueError(fail_msg)
|
|
||||||
|
|
||||||
return param_type(value)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_values(values, type_map):
|
|
||||||
"""Parses hyperparameter values from a string into a python map.
|
|
||||||
|
|
||||||
`values` is a string containing comma-separated `name=value` pairs.
|
|
||||||
For each pair, the value of the hyperparameter named `name` is set to
|
|
||||||
`value`.
|
|
||||||
|
|
||||||
If a hyperparameter name appears multiple times in `values`, a ValueError
|
|
||||||
is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
|
|
||||||
|
|
||||||
If a hyperparameter name in both an index assignment and scalar assignment,
|
|
||||||
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
|
|
||||||
|
|
||||||
The hyperparameter name may contain '.' symbols, which will result in an
|
|
||||||
attribute name that is only accessible through the getattr and setattr
|
|
||||||
functions. (And must be first explicit added through add_hparam.)
|
|
||||||
|
|
||||||
WARNING: Use of '.' in your variable names is allowed, but is not well
|
|
||||||
supported and not recommended.
|
|
||||||
|
|
||||||
The `value` in `name=value` must follows the syntax according to the
|
|
||||||
type of the parameter:
|
|
||||||
|
|
||||||
* Scalar integer: A Python-parsable integer point value. E.g.: 1,
|
|
||||||
100, -12.
|
|
||||||
* Scalar float: A Python-parsable floating point value. E.g.: 1.0,
|
|
||||||
-.54e89.
|
|
||||||
* Boolean: Either true or false.
|
|
||||||
* Scalar string: A non-empty sequence of characters, excluding comma,
|
|
||||||
spaces, and square brackets. E.g.: foo, bar_1.
|
|
||||||
* List: A comma separated list of scalar values of the parameter type
|
|
||||||
enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
|
|
||||||
|
|
||||||
When index assignment is used, the corresponding type_map key should be the
|
|
||||||
list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
|
|
||||||
"arr[1]").
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values: String. Comma separated list of `name=value` pairs where
|
|
||||||
'value' must follow the syntax described above.
|
|
||||||
type_map: A dictionary mapping hyperparameter names to types. Note every
|
|
||||||
parameter name in values must be a key in type_map. The values must
|
|
||||||
conform to the types indicated, where a value V is said to conform to a
|
|
||||||
type T if either V has type T, or V is a list of elements of type T.
|
|
||||||
Hence, for a multidimensional parameter 'x' taking float values,
|
|
||||||
'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A python map mapping each name to either:
|
|
||||||
* A scalar value.
|
|
||||||
* A list of scalar values.
|
|
||||||
* A dictionary mapping index numbers to scalar values.
|
|
||||||
(e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there is a problem with input.
|
|
||||||
* If `values` cannot be parsed.
|
|
||||||
* If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
|
|
||||||
* If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
|
|
||||||
'a[1]=1,a[1]=2', or 'a=1,a=[1]')
|
|
||||||
"""
|
|
||||||
results_dictionary = {}
|
|
||||||
pos = 0
|
|
||||||
while pos < len(values):
|
|
||||||
m = PARAM_RE.match(values, pos)
|
|
||||||
if not m:
|
|
||||||
raise ValueError('Malformed hyperparameter value: %s' %
|
|
||||||
values[pos:])
|
|
||||||
# Check that there is a comma between parameters and move past it.
|
|
||||||
pos = m.end()
|
|
||||||
# Parse the values.
|
|
||||||
m_dict = m.groupdict()
|
|
||||||
name = m_dict['name']
|
|
||||||
if name not in type_map:
|
|
||||||
raise ValueError('Unknown hyperparameter type for %s' % name)
|
|
||||||
type_ = type_map[name]
|
|
||||||
|
|
||||||
# Set up correct parsing function (depending on whether type_ is a bool)
|
|
||||||
if type_ == bool:
|
|
||||||
|
|
||||||
def parse_bool(value):
|
|
||||||
if value in ['true', 'True']:
|
|
||||||
return True
|
|
||||||
elif value in ['false', 'False']:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
return bool(int(value))
|
|
||||||
except ValueError:
|
|
||||||
_parse_fail(name, type_, value, values)
|
|
||||||
|
|
||||||
parse = parse_bool
|
|
||||||
else:
|
|
||||||
parse = type_
|
|
||||||
|
|
||||||
# If a singe value is provided
|
|
||||||
if m_dict['val'] is not None:
|
|
||||||
_process_scalar_value(name, parse, type_, m_dict, values,
|
|
||||||
results_dictionary)
|
|
||||||
|
|
||||||
# If the assigned value is a list:
|
|
||||||
elif m_dict['vals'] is not None:
|
|
||||||
_process_list_value(name, parse, type_, m_dict, values,
|
|
||||||
results_dictionary)
|
|
||||||
|
|
||||||
else: # Not assigned a list or value
|
|
||||||
_parse_fail(name, type_, '', values)
|
|
||||||
|
|
||||||
return results_dictionary
|
|
||||||
|
|
||||||
|
|
||||||
class HParams(object):
|
|
||||||
"""Class to hold a set of hyperparameters as name-value pairs.
|
|
||||||
|
|
||||||
A `HParams` object holds hyperparameters used to build and train a model,
|
|
||||||
such as the number of hidden units in a neural net layer or the learning rate
|
|
||||||
to use when training.
|
|
||||||
|
|
||||||
You first create a `HParams` object by specifying the names and values of the
|
|
||||||
hyperparameters.
|
|
||||||
|
|
||||||
To make them easily accessible the parameter names are added as direct
|
|
||||||
attributes of the class. A typical usage is as follows:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Create a HParams object specifying names and values of the model
|
|
||||||
# hyperparameters:
|
|
||||||
hparams = HParams(learning_rate=0.1, num_hidden_units=100)
|
|
||||||
|
|
||||||
# The hyperparameter are available as attributes of the HParams object:
|
|
||||||
hparams.learning_rate ==> 0.1
|
|
||||||
hparams.num_hidden_units ==> 100
|
|
||||||
```
|
|
||||||
|
|
||||||
Hyperparameters have type, which is inferred from the type of their value
|
|
||||||
passed at construction type. The currently supported types are: integer,
|
|
||||||
float, boolean, string, and list of integer, float, boolean, or string.
|
|
||||||
|
|
||||||
You can override hyperparameter values by calling the
|
|
||||||
[`parse()`](#HParams.parse) method, passing a string of comma separated
|
|
||||||
`name=value` pairs. This is intended to make it possible to override
|
|
||||||
any hyperparameter values from a single command-line flag to which
|
|
||||||
the user passes 'hyper-param=value' pairs. It avoids having to define
|
|
||||||
one flag for each hyperparameter.
|
|
||||||
|
|
||||||
The syntax expected for each value depends on the type of the parameter.
|
|
||||||
See `parse()` for a description of the syntax.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Define a command line flag to pass name=value pairs.
|
|
||||||
# For example using argparse:
|
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser(description='Train my model.')
|
|
||||||
parser.add_argument('--hparams', type=str,
|
|
||||||
help='Comma separated list of "name=value" pairs.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
...
|
|
||||||
def my_program():
|
|
||||||
# Create a HParams object specifying the names and values of the
|
|
||||||
# model hyperparameters:
|
|
||||||
hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
|
|
||||||
activations=['relu', 'tanh'])
|
|
||||||
|
|
||||||
# Override hyperparameters values by parsing the command line
|
|
||||||
hparams.parse(args.hparams)
|
|
||||||
|
|
||||||
# If the user passed `--hparams=learning_rate=0.3` on the command line
|
|
||||||
# then 'hparams' has the following attributes:
|
|
||||||
hparams.learning_rate ==> 0.3
|
|
||||||
hparams.num_hidden_units ==> 100
|
|
||||||
hparams.activations ==> ['relu', 'tanh']
|
|
||||||
|
|
||||||
# If the hyperparameters are in json format use parse_json:
|
|
||||||
hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
_HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
|
|
||||||
|
|
||||||
def __init__(self, hparam_def=None, model_structure=None, **kwargs):
|
|
||||||
"""Create an instance of `HParams` from keyword arguments.
|
|
||||||
|
|
||||||
The keyword arguments specify name-values pairs for the hyperparameters.
|
|
||||||
The parameter types are inferred from the type of the values passed.
|
|
||||||
|
|
||||||
The parameter names are added as attributes of `HParams` object, so they
|
|
||||||
can be accessed directly with the dot notation `hparams._name_`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Define 3 hyperparameters: 'learning_rate' is a float parameter,
|
|
||||||
# 'num_hidden_units' an integer parameter, and 'activation' a string
|
|
||||||
# parameter.
|
|
||||||
hparams = tf.HParams(
|
|
||||||
learning_rate=0.1, num_hidden_units=100, activation='relu')
|
|
||||||
|
|
||||||
hparams.activation ==> 'relu'
|
|
||||||
```
|
|
||||||
|
|
||||||
Note that a few names are reserved and cannot be used as hyperparameter
|
|
||||||
names. If you use one of the reserved name the constructor raises a
|
|
||||||
`ValueError`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
|
|
||||||
protocol buffer. If provided, this object is initialized by
|
|
||||||
deserializing hparam_def. Otherwise **kwargs is used.
|
|
||||||
model_structure: An instance of ModelStructure, defining the feature
|
|
||||||
crosses to be used in the Trial.
|
|
||||||
**kwargs: Key-value pairs where the key is the hyperparameter name and
|
|
||||||
the value is the value for the parameter.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If both `hparam_def` and initialization values are provided,
|
|
||||||
or if one of the arguments is invalid.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Register the hyperparameters and their type in _hparam_types.
|
|
||||||
# This simplifies the implementation of parse().
|
|
||||||
# _hparam_types maps the parameter name to a tuple (type, bool).
|
|
||||||
# The type value is the type of the parameter for scalar hyperparameters,
|
|
||||||
# or the type of the list elements for multidimensional hyperparameters.
|
|
||||||
# The bool value is True if the value is a list, False otherwise.
|
|
||||||
self._hparam_types = {}
|
|
||||||
self._model_structure = model_structure
|
|
||||||
if hparam_def:
|
|
||||||
## self._init_from_proto(hparam_def)
|
|
||||||
## if kwargs:
|
|
||||||
## raise ValueError('hparam_def and initialization values are '
|
|
||||||
## 'mutually exclusive')
|
|
||||||
raise ValueError('hparam_def has been disabled in this version')
|
|
||||||
else:
|
|
||||||
for name, value in six.iteritems(kwargs):
|
|
||||||
self.add_hparam(name, value)
|
|
||||||
|
|
||||||
## def _init_from_proto(self, hparam_def):
|
|
||||||
## """Creates a new HParams from `HParamDef` protocol buffer.
|
|
||||||
##
|
|
||||||
## Args:
|
|
||||||
## hparam_def: `HParamDef` protocol buffer.
|
|
||||||
## """
|
|
||||||
## assert isinstance(hparam_def, hparam_pb2.HParamDef)
|
|
||||||
## for name, value in hparam_def.hparam.items():
|
|
||||||
## kind = value.WhichOneof('kind')
|
|
||||||
## if kind.endswith('_value'):
|
|
||||||
## # Single value.
|
|
||||||
## if kind.startswith('int64'):
|
|
||||||
## # Setting attribute value to be 'int' to ensure the type is compatible
|
|
||||||
## # with both Python2 and Python3.
|
|
||||||
## self.add_hparam(name, int(getattr(value, kind)))
|
|
||||||
## elif kind.startswith('bytes'):
|
|
||||||
## # Setting attribute value to be 'str' to ensure the type is compatible
|
|
||||||
## # with both Python2 and Python3. UTF-8 encoding is assumed.
|
|
||||||
## self.add_hparam(name, compat.as_str(getattr(value, kind)))
|
|
||||||
## else:
|
|
||||||
## self.add_hparam(name, getattr(value, kind))
|
|
||||||
## else:
|
|
||||||
## # List of values.
|
|
||||||
## if kind.startswith('int64'):
|
|
||||||
## # Setting attribute value to be 'int' to ensure the type is compatible
|
|
||||||
## # with both Python2 and Python3.
|
|
||||||
## self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
|
|
||||||
## elif kind.startswith('bytes'):
|
|
||||||
## # Setting attribute value to be 'str' to ensure the type is compatible
|
|
||||||
## # with both Python2 and Python3. UTF-8 encoding is assumed.
|
|
||||||
## self.add_hparam(
|
|
||||||
## name, [compat.as_str(v) for v in getattr(value, kind).value])
|
|
||||||
## else:
|
|
||||||
## self.add_hparam(name, [v for v in getattr(value, kind).value])
|
|
||||||
|
|
||||||
def add_hparam(self, name, value):
|
|
||||||
"""Adds {name, value} pair to hyperparameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the hyperparameter.
|
|
||||||
value: Value of the hyperparameter. Can be one of the following types:
|
|
||||||
int, float, string, int list, float list, or string list.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if one of the arguments is invalid.
|
|
||||||
"""
|
|
||||||
# Keys in kwargs are unique, but 'name' could the name of a pre-existing
|
|
||||||
# attribute of this object. In that case we refuse to use it as a
|
|
||||||
# hyperparameter name.
|
|
||||||
if getattr(self, name, None) is not None:
|
|
||||||
raise ValueError('Hyperparameter name is reserved: %s' % name)
|
|
||||||
if isinstance(value, (list, tuple)):
|
|
||||||
if not value:
|
|
||||||
raise ValueError(
|
|
||||||
'Multi-valued hyperparameters cannot be empty: %s' % name)
|
|
||||||
self._hparam_types[name] = (type(value[0]), True)
|
|
||||||
else:
|
|
||||||
self._hparam_types[name] = (type(value), False)
|
|
||||||
setattr(self, name, value)
|
|
||||||
|
|
||||||
def set_hparam(self, name, value):
|
|
||||||
"""Set the value of an existing hyperparameter.
|
|
||||||
|
|
||||||
This function verifies that the type of the value matches the type of the
|
|
||||||
existing hyperparameter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the hyperparameter.
|
|
||||||
value: New value of the hyperparameter.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there is a type mismatch.
|
|
||||||
"""
|
|
||||||
param_type, is_list = self._hparam_types[name]
|
|
||||||
if isinstance(value, list):
|
|
||||||
if not is_list:
|
|
||||||
raise ValueError(
|
|
||||||
'Must not pass a list for single-valued parameter: %s' %
|
|
||||||
name)
|
|
||||||
setattr(self, name, [
|
|
||||||
_cast_to_type_if_compatible(name, param_type, v) for v in value
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
if is_list:
|
|
||||||
raise ValueError(
|
|
||||||
'Must pass a list for multi-valued parameter: %s.' % name)
|
|
||||||
setattr(self, name,
|
|
||||||
_cast_to_type_if_compatible(name, param_type, value))
|
|
||||||
|
|
||||||
def del_hparam(self, name):
|
|
||||||
"""Removes the hyperparameter with key 'name'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the hyperparameter.
|
|
||||||
"""
|
|
||||||
if hasattr(self, name):
|
|
||||||
delattr(self, name)
|
|
||||||
del self._hparam_types[name]
|
|
||||||
|
|
||||||
def parse(self, values):
|
|
||||||
"""Override hyperparameter values, parsing new values from a string.
|
|
||||||
|
|
||||||
See parse_values for more detail on the allowed format for values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values: String. Comma separated list of `name=value` pairs where
|
|
||||||
'value' must follow the syntax described above.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The `HParams` instance.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `values` cannot be parsed.
|
|
||||||
"""
|
|
||||||
type_map = dict()
|
|
||||||
for name, t in self._hparam_types.items():
|
|
||||||
param_type, _ = t
|
|
||||||
type_map[name] = param_type
|
|
||||||
|
|
||||||
values_map = parse_values(values, type_map)
|
|
||||||
return self.override_from_dict(values_map)
|
|
||||||
|
|
||||||
def override_from_dict(self, values_dict):
|
|
||||||
"""Override hyperparameter values, parsing new values from a dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values_dict: Dictionary of name:value pairs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The `HParams` instance.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `values_dict` cannot be parsed.
|
|
||||||
"""
|
|
||||||
for name, value in values_dict.items():
|
|
||||||
self.set_hparam(name, value)
|
|
||||||
return self
|
|
||||||
|
|
||||||
## @deprecation.deprecated(None, 'Use `override_from_dict`.')
|
|
||||||
|
|
||||||
def set_from_map(self, values_map):
|
|
||||||
"""DEPRECATED. Use override_from_dict."""
|
|
||||||
return self.override_from_dict(values_dict=values_map)
|
|
||||||
|
|
||||||
def set_model_structure(self, model_structure):
|
|
||||||
self._model_structure = model_structure
|
|
||||||
|
|
||||||
def get_model_structure(self):
|
|
||||||
return self._model_structure
|
|
||||||
|
|
||||||
def to_json(self, indent=None, separators=None, sort_keys=False):
|
|
||||||
"""Serializes the hyperparameters into JSON.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
indent: If a non-negative integer, JSON array elements and object members
|
|
||||||
will be pretty-printed with that indent level. An indent level of 0, or
|
|
||||||
negative, will only insert newlines. `None` (the default) selects the
|
|
||||||
most compact representation.
|
|
||||||
separators: Optional `(item_separator, key_separator)` tuple. Default is
|
|
||||||
`(', ', ': ')`.
|
|
||||||
sort_keys: If `True`, the output dictionaries will be sorted by key.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A JSON string.
|
|
||||||
"""
|
|
||||||
return json.dumps(
|
|
||||||
self.values(),
|
|
||||||
indent=indent,
|
|
||||||
separators=separators,
|
|
||||||
sort_keys=sort_keys)
|
|
||||||
|
|
||||||
def parse_json(self, values_json):
|
|
||||||
"""Override hyperparameter values, parsing new values from a json object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values_json: String containing a json object of name:value pairs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The `HParams` instance.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `values_json` cannot be parsed.
|
|
||||||
"""
|
|
||||||
values_map = json.loads(values_json)
|
|
||||||
return self.override_from_dict(values_map)
|
|
||||||
|
|
||||||
def values(self):
|
|
||||||
"""Return the hyperparameter values as a Python dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with hyperparameter names as keys. The values are the
|
|
||||||
hyperparameter values.
|
|
||||||
"""
|
|
||||||
return {n: getattr(self, n) for n in self._hparam_types.keys()}
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
|
||||||
"""Returns the value of `key` if it exists, else `default`."""
|
|
||||||
if key in self._hparam_types:
|
|
||||||
# Ensure that default is compatible with the parameter type.
|
|
||||||
if default is not None:
|
|
||||||
param_type, is_param_list = self._hparam_types[key]
|
|
||||||
type_str = 'list<%s>' % param_type if is_param_list else str(
|
|
||||||
param_type)
|
|
||||||
fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
|
|
||||||
'default=%s' % (key, type_str, default))
|
|
||||||
|
|
||||||
is_default_list = isinstance(default, list)
|
|
||||||
if is_param_list != is_default_list:
|
|
||||||
raise ValueError(fail_msg)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if is_default_list:
|
|
||||||
for value in default:
|
|
||||||
_cast_to_type_if_compatible(key, param_type, value)
|
|
||||||
else:
|
|
||||||
_cast_to_type_if_compatible(key, param_type, default)
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError('%s. %s' % (fail_msg, e))
|
|
||||||
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
return default
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
|
||||||
return key in self._hparam_types
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return str(sorted(self.values().items()))
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return '%s(%s)' % (type(self).__name__, self.__str__())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_kind_name(param_type, is_list):
|
|
||||||
"""Returns the field name given parameter type and is_list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
param_type: Data type of the hparam.
|
|
||||||
is_list: Whether this is a list.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the field name.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If parameter type is not recognized.
|
|
||||||
"""
|
|
||||||
if issubclass(param_type, bool):
|
|
||||||
# This check must happen before issubclass(param_type, six.integer_types),
|
|
||||||
# since Python considers bool to be a subclass of int.
|
|
||||||
typename = 'bool'
|
|
||||||
elif issubclass(param_type, six.integer_types):
|
|
||||||
# Setting 'int' and 'long' types to be 'int64' to ensure the type is
|
|
||||||
# compatible with both Python2 and Python3.
|
|
||||||
typename = 'int64'
|
|
||||||
elif issubclass(param_type, (six.string_types, six.binary_type)):
|
|
||||||
# Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
|
|
||||||
# compatible with both Python2 and Python3.
|
|
||||||
typename = 'bytes'
|
|
||||||
elif issubclass(param_type, float):
|
|
||||||
typename = 'float'
|
|
||||||
else:
|
|
||||||
raise ValueError('Unsupported parameter type: %s' % str(param_type))
|
|
||||||
|
|
||||||
suffix = 'list' if is_list else 'value'
|
|
||||||
return '_'.join([typename, suffix])
|
|
||||||
|
|
||||||
|
|
||||||
## def to_proto(self, export_scope=None): # pylint: disable=unused-argument
|
|
||||||
## """Converts a `HParams` object to a `HParamDef` protocol buffer.
|
|
||||||
##
|
|
||||||
## Args:
|
|
||||||
## export_scope: Optional `string`. Name scope to remove.
|
|
||||||
##
|
|
||||||
## Returns:
|
|
||||||
## A `HParamDef` protocol buffer.
|
|
||||||
## """
|
|
||||||
## hparam_proto = hparam_pb2.HParamDef()
|
|
||||||
## for name in self._hparam_types:
|
|
||||||
## # Parse the values.
|
|
||||||
## param_type, is_list = self._hparam_types.get(name, (None, None))
|
|
||||||
## kind = HParams._get_kind_name(param_type, is_list)
|
|
||||||
##
|
|
||||||
## if is_list:
|
|
||||||
## if kind.startswith('bytes'):
|
|
||||||
## v_list = [compat.as_bytes(v) for v in getattr(self, name)]
|
|
||||||
## else:
|
|
||||||
## v_list = [v for v in getattr(self, name)]
|
|
||||||
## getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
|
|
||||||
## else:
|
|
||||||
## v = getattr(self, name)
|
|
||||||
## if kind.startswith('bytes'):
|
|
||||||
## v = compat.as_bytes(getattr(self, name))
|
|
||||||
## setattr(hparam_proto.hparam[name], kind, v)
|
|
||||||
##
|
|
||||||
## return hparam_proto
|
|
||||||
|
|
||||||
## @staticmethod
|
|
||||||
## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument
|
|
||||||
## return HParams(hparam_def=hparam_def)
|
|
||||||
|
|
||||||
## ops.register_proto_function(
|
|
||||||
## 'hparams',
|
|
||||||
## proto_type=hparam_pb2.HParamDef,
|
|
||||||
## to_proto=HParams.to_proto,
|
|
||||||
## from_proto=HParams.from_proto)
|
|
|
@ -1,8 +0,0 @@
|
||||||
Source: hparam.py copied from tensorflow v1.12.0.
|
|
||||||
|
|
||||||
https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
|
|
||||||
|
|
||||||
with the following:
|
|
||||||
wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
|
|
||||||
|
|
||||||
Once all other tensorflow dependencies of these file are removed, the class keeps its goal. Functions not available due to this process are not used in this project.
|
|
|
@ -19,7 +19,7 @@ import paddle
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
|
|
||||||
from weight_norm import Conv2D, Conv2DTranspose
|
from .weight_norm import Conv2D, Conv2DTranspose
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(dg.Layer):
|
class Conv1D(dg.Layer):
|
||||||
|
|
|
@ -18,8 +18,8 @@ import paddle.fluid.dygraph as dg
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import conv
|
from .import conv
|
||||||
import weight_norm as weight_norm
|
from . import weight_norm
|
||||||
|
|
||||||
|
|
||||||
def FC(name_scope,
|
def FC(name_scope,
|
||||||
|
|
Loading…
Reference in New Issue