# Source code for nengo_dl.objectives

"""
Some common objective functions (for use with the ``objective`` argument in
`.Simulator.train` or `.Simulator.loss`).
"""

import tensorflow as tf


def mse(outputs, targets):
    """
    Compute Mean Squared Error between given outputs and targets.

    Any ``nan`` entries in ``targets`` contribute zero error.  They are still
    counted among the elements the mean is taken over.

    Parameters
    ----------
    outputs : ``tf.Tensor``
        Output values from a Probe in a network.
    targets : ``tf.Tensor``
        Target values for a Probe in a network.

    Returns
    -------
    mse : ``tf.Tensor``
        Tensor representing the mean squared error.
    """
    missing = tf.is_nan(targets)
    # Where the target is nan, substitute the output itself so that the
    # difference (and therefore the squared error) is exactly zero there.
    filled_targets = tf.where(missing, outputs, targets)
    error = filled_targets - outputs
    return tf.reduce_mean(tf.square(error))
class Regularize:
    """
    An objective function to apply regularization penalties.

    Parameters
    ----------
    order : int or str
        Order of the regularization norm (e.g. ``1`` for L1 norm, ``2`` for
        L2 norm).  See https://www.tensorflow.org/api_docs/python/tf/norm for
        a full description of the possible values for this parameter.
    axis : int or None
        The axis of the probed signal along which to compute norm.  If None
        (the default), the signal is flattened and the norm is computed across
        the resulting vector.  Note that these are only the axes with respect
        to the output on a single timestep (i.e. batch/time dimensions are not
        included).  Negative values count back from the last signal axis.
    weight : float
        Scaling weight to apply to regularization penalty.

    Notes
    -----
    The mean will be computed across all the non-``axis`` dimensions after
    computing the norm (including batch/time) in order to compute the overall
    objective value.
    """

    def __init__(self, order=2, axis=None, weight=None):
        self.order = order
        self.axis = axis
        self.weight = weight

    @staticmethod
    def _tensor_axis(axis):
        """
        Convert a per-timestep signal axis into an axis of the full probe
        tensor, whose first two dimensions are batch and time.
        """
        # Only non-negative indices need shifting past the two leading
        # (batch, time) dimensions; negative indices count from the end and
        # already refer to the intended signal dimension.
        return axis if axis < 0 else axis + 2

    def __call__(self, x):
        """
        Compute the regularization penalty for probe output ``x``.

        Parameters
        ----------
        x : ``tf.Tensor``
            Probe output; the first two dimensions are treated as batch and
            time, the remaining ones as the per-timestep signal.

        Returns
        -------
        penalty : ``tf.Tensor``
            Scalar tensor: the mean of the computed norms, multiplied by
            ``weight`` if one was given.
        """
        if self.axis is None:
            ndims = x.get_shape().ndims
            # Flatten the signal dimensions into one (keeping batch/time) so
            # the norm is taken over the whole per-timestep signal.  If the
            # static rank is unknown, reshape unconditionally; for a rank-3
            # input this reshape leaves the shape unchanged.
            if ndims is None or ndims > 3:
                x = tf.reshape(
                    x, tf.concat([tf.shape(x)[:2], (-1,)], axis=0))
            axis = 2
        else:
            axis = self._tensor_axis(self.axis)

        output = tf.reduce_mean(tf.norm(x, axis=axis, ord=self.order))

        if self.weight is not None:
            output *= self.weight

        return output