Source code for nengo_dl.signals

from collections import defaultdict, OrderedDict, Mapping
import logging

from nengo.builder.signal import Signal
from nengo.exceptions import BuildError
import numpy as np
import tensorflow as tf

# Module-level logger; SignalDict.scatter and SignalDict.gather emit their
# debug traces (values, indices, base arrays) through it.
logger = logging.getLogger(__name__)


class TensorSignal(object):
    """Represents a tensor as an indexed view into a base array.

    Parameters
    ----------
    indices : tuple or list or :class:`~numpy:numpy.ndarray` of int
        Indices along the first axis of the base array corresponding to the
        data for this signal
    key : object
        Key mapping to the base array that contains the data for this signal
    dtype : :class:`~numpy:numpy.dtype`
        dtype of the values represented by this signal
    shape : tuple of int
        View shape of this signal (may differ from shape of base array)
    minibatch_size : int
        If not None then this signal contains a minibatch dimension with the
        given size
    constant : callable
        A function that returns a TensorFlow constant (will be provided
        by :meth:`.signals.SignalDict.get_tensor_signal`)
    label : str, optional
        Name for this signal, used to make debugging easier
    """

    def __init__(self, indices, key, dtype, shape, minibatch_size,
                 constant, label="TensorSignal"):
        # make indices read-only (note: ``np.asarray`` does not copy an
        # existing ndarray, so the array object passed in is the one that
        # gets flagged as non-writeable)
        assert isinstance(indices, (tuple, list, np.ndarray))
        self._indices = np.asarray(indices)
        self._indices.flags.writeable = False

        # lazily-built TensorFlow representations (see the ``tf_*``
        # properties). ``_tf_slice`` uses -1 to mean "not computed yet",
        # because None is a valid cached result ("cannot be sliced")
        self._tf_shape = None
        self._tf_indices = None
        self._tf_slice = -1

        self.key = key
        self.dtype = dtype
        self.shape = shape
        self.minibatch_size = minibatch_size
        self.constant = constant
        self.label = label

    @property
    def indices(self):
        """Read-only array of indices into the first axis of the base."""
        return self._indices

    @indices.setter
    def indices(self, _):
        raise BuildError("Indices are read only")

    @property
    def ndim(self):
        """Number of dimensions in the view shape (excluding minibatch)."""
        return len(self.shape)

    def __repr__(self):
        return "TensorSignal(key=%s, shape=%s, label=%s)" % (
            self.key, self.shape, self.label)

    def __getitem__(self, indices):
        """Create a new TensorSignal representing a subset (slice or
        advanced indexing) of the indices of this TensorSignal.

        Parameters
        ----------
        indices : slice or list of int
            The desired subset of the indices in this TensorSignal

        Returns
        -------
        :class:`.signals.TensorSignal`
            A new TensorSignal representing the subset of this TensorSignal
            (``self`` is returned unchanged if ``indices`` is ``Ellipsis``
            or ``None``)
        """

        if indices is Ellipsis or indices is None:
            return self

        new_indices = self.indices[indices]
        return TensorSignal(
            new_indices, self.key, self.dtype,
            (len(new_indices),) + self.shape[1:], self.minibatch_size,
            self.constant, label=self.label + ".slice")

    def reshape(self, shape):
        """Create a new TensorSignal representing a reshaped view of the
        same data in this TensorSignal (size of data must remain unchanged).

        Parameters
        ----------
        shape : tuple of int
            New shape for the signal (one dimension can be -1 to indicate
            an inferred dimension size, as in numpy).  Any sequence of ints
            is accepted; it is stored as a tuple.

        Returns
        -------
        :class:`.signals.TensorSignal`
            New TensorSignal representing the same data as this signal but
            with the given shape

        Raises
        ------
        BuildError
            If more than one dimension is -1, if no integer length fits the
            inferred dimension, or if the number of elements differs
        """

        # always store a tuple: ``full_shape`` and ``__getitem__``
        # concatenate ``self.shape`` with other tuples, which fails for lists
        shape = tuple(shape)
        n_elem = int(np.prod(self.shape))
        n_inferred = shape.count(-1)

        if n_inferred > 1:
            raise BuildError("Only one inferred dimension allowed in reshape")
        elif n_inferred == 1:
            # replace -1 with inferred dimension
            n_known = int(np.prod([x for x in shape if x != -1]))
            # a zero-sized known dimension leaves the inferred one undefined
            if n_known == 0 or n_elem % n_known != 0:
                raise BuildError("No valid length for inferred dimension")
            shape = tuple(x if x != -1 else n_elem // n_known for x in shape)
        elif np.prod(shape) != n_elem:
            raise BuildError("Number of elements don't match in reshape")

        return TensorSignal(
            self.indices, self.key, self.dtype, shape, self.minibatch_size,
            self.constant, label=self.label + ".reshape(%s)" % (shape,))

    def broadcast(self, axis, length):
        """Add a new dimension by broadcasting this signal along ``axis``
        for the given length.

        Parameters
        ----------
        axis : 0 or -1
            Where to insert the new dimension (currently only supports either
            the beginning or end of the array)
        length : int
            The number of times to duplicate signal along the broadcast
            dimension

        Returns
        -------
        :class:`.signals.TensorSignal`
            TensorSignal with new broadcasted shape
        """

        assert axis in (0, -1)
        # this only works on vectors
        assert self.ndim == 1 and not self.minibatched

        # duplicate the indices along the new axis, then flatten so that
        # they again index the first axis of the base array
        indices = np.stack([self.indices] * length, axis=axis)
        indices = np.reshape(indices, (-1,))

        if axis == -1:
            display_shape = self.shape + (length,)
        else:
            display_shape = (length,) + self.shape

        return TensorSignal(
            indices, self.key, self.dtype, display_shape,
            self.minibatch_size, self.constant,
            label=self.label + ".broadcast(%d, %d)" % (axis, length))

    @property
    def tf_shape(self):
        """``tf.constant`` (int32) holding ``full_shape``, built on first
        access and cached."""
        if self._tf_shape is None:
            self._tf_shape = tf.constant(self.full_shape, dtype=tf.int32)
        return self._tf_shape

    @property
    def tf_indices(self):
        """The indices as an int32 tensor created through ``constant``,
        built on first access and cached."""
        if self._tf_indices is None:
            self._tf_indices = self.constant(self.indices, dtype=tf.int32)
        return self._tf_indices

    @property
    def tf_slice(self):
        """``(start, stop, step)`` constants if the indices form an evenly
        spaced range, otherwise None.  Computed on first access and cached."""
        if self._tf_slice == -1:
            if len(self.indices) == 0:
                # nothing to take a start/stop from; treat as not sliceable
                self._tf_slice = None
                return self._tf_slice

            start = self.indices[0]
            stop = self.indices[-1] + 1
            step = (self.indices[1] - self.indices[0]
                    if len(self.indices) > 1 else 1)
            # step == 0 (repeated index) cannot be expressed as a range
            if step != 0 and np.array_equal(self.indices,
                                            np.arange(start, stop, step)):
                self._tf_slice = (tf.constant([start]), tf.constant([stop]),
                                  tf.constant([step]))
            else:
                self._tf_slice = None

        return self._tf_slice

    @property
    def full_shape(self):
        """Shape including the minibatch dimension."""
        return (self.shape + (self.minibatch_size,) if self.minibatched
                else self.shape)

    @property
    def minibatched(self):
        """Whether or not this TensorSignal contains a minibatch dimension."""
        return self.minibatch_size is not None
class SignalDict(Mapping):
    """Maps :class:`~nengo:nengo.builder.Signal` objects to
    :class:`.TensorSignal` views and reads/writes the corresponding data in
    the base ``tf.Tensor`` arrays via gather/scatter operations.

    Parameters
    ----------
    dtype : ``tf.DType``
        Floating point precision used in signals
    minibatch_size : int
        Number of items in each minibatch
    """

    def __init__(self, dtype, minibatch_size):
        self.dtype = dtype
        self.minibatch_size = minibatch_size

        # Signal -> TensorSignal lookup and the base arrays (filled in later)
        self.sig_map = {}
        self.bases = None
        self.internal_vars = OrderedDict()
        self.constant_phs = {}

        # bookkeeping used to order writes after the reads that precede them
        self.reads_by_base = defaultdict(list)
        self.gather_bases = []

        # logging
        self.read_types = defaultdict(int)
        self.write_types = defaultdict(int)

    def scatter(self, dst, val, mode="update"):
        """Writes ``val`` into the base array at the location described by
        ``dst``.

        Parameters
        ----------
        dst : :class:`.TensorSignal`
            Signal indicating the data to be modified in base array
        val : ``tf.Tensor``
            Update data (same shape as ``dst``, i.e. a dense array <= the size
            of the base array)
        mode : "update" or "inc"
            Overwrite/add the data at ``dst`` with ``val``
        """

        val_dtype = val.dtype
        if val_dtype.is_floating and val_dtype.base_dtype != self.dtype:
            raise BuildError("Tensor detected with wrong dtype (%s), should "
                             "be %s." % (val_dtype.base_dtype, self.dtype))

        # bring val into the layout of the destination base array
        base = self.bases[dst.key]
        base.get_shape().assert_is_fully_defined()
        val.get_shape().assert_is_fully_defined()
        target_shape = ((dst.shape[0],) +
                        tuple(base.get_shape().as_list()[1:]))
        if val.get_shape() != target_shape:
            val = tf.reshape(val, dst.tf_shape)

        # reads of this base recorded since its last write
        pending_reads = self.reads_by_base[base]

        logger.debug("scatter")
        logger.debug("values %s", val)
        logger.debug("dst %s", dst)
        logger.debug("indices %s", dst.indices)
        logger.debug("dst base %s", base)
        logger.debug("reads_by_base %s", pending_reads)

        # the write must run after every pending read of the target (each
        # write replaces the base object, so only reads since the previous
        # write are tracked under this key)
        with tf.control_dependencies(pending_reads):
            self.bases[dst.key] = self._scatter_f_var(dst, val, mode=mode)
        written = self.bases[dst.key]

        # workflow is gather -> computation -> scatter: the result of this
        # scatter signals that the preceding gathers have been consumed, so
        # later writes to those gathered bases are made to wait on it
        for gathered in self.gather_bases:
            self.reads_by_base[gathered].append(written)
        self.gather_bases = []

        logger.debug("new dst base %s", written)

    def _scatter_f_var(self, dst, src, mode="update"):
        # Applies the write directly to the base variable; a plain assign is
        # used when ``dst`` spans the whole variable, otherwise a sparse
        # scatter op on the selected rows.
        var = self.bases[dst.key]
        accumulate = mode == "inc"

        spans_whole_var = (
            dst.tf_slice is not None and
            var.get_shape().is_compatible_with(src.get_shape()) and
            dst.indices[0] == 0 and
            dst.indices[-1] == var.get_shape()[0].value - 1 and
            len(dst.indices) == var.get_shape()[0])

        if spans_whole_var:
            if accumulate:
                write_type, assign_op = "assign_add", tf.assign_add
            else:
                write_type, assign_op = "assign", tf.assign
            result = assign_op(var, src, use_locking=False)
        else:
            if accumulate:
                write_type, scatter_op = "scatter_add", tf.scatter_add
            else:
                write_type, scatter_op = "scatter_update", tf.scatter_update
            result = scatter_op(var, dst.tf_indices, src, use_locking=False)

        self.write_types[write_type] += 1

        return result

    def gather(self, src, force_copy=False):
        """Reads the data described by ``src`` out of its base array.

        Parameters
        ----------
        src : :class:`.TensorSignal`
            Signal indicating the data to be read from base array
        force_copy : bool, optional
            If True, always perform a gather, not a slice (this forces a
            copy). Note that setting ``force_copy=False`` does not guarantee
            that a copy won't be performed.

        Returns
        -------
        ``tf.Tensor``
            Tensor object corresponding to a dense subset of data from the
            base array
        """

        logger.debug("gather")
        logger.debug("src %s", src)
        logger.debug("indices %s", src.indices)
        logger.debug("src base %s", self.bases[src.key])

        var = self.bases[src.key]

        # cheapest option first: the variable itself, then a strided slice,
        # and a true gather only when forced or when no slice exists
        if force_copy or src.tf_slice is None:
            read_type = "gather"
            result = tf.gather(var, src.tf_indices)
        elif (src.indices[0] == 0 and
              src.indices[-1] == var.get_shape()[0].value - 1 and
              len(src.indices) == var.get_shape()[0]):
            read_type = "identity"
            result = var
        else:
            read_type = "strided_slice"
            result = tf.strided_slice(var, *src.tf_slice)
        self.read_types[read_type] += 1

        # apply the view shape of ``src`` when it differs from what we read
        view_shape = src.full_shape
        if result.get_shape() != view_shape:
            result = tf.reshape(result, src.tf_shape)

            # static shape inference does not always recover the shape
            result.set_shape(view_shape)

        # record the read so that later writes to this base are scheduled
        # after it
        self.mark_gather(src)

        return result

    def mark_gather(self, src):
        """Records ``src`` as having been read without performing a gather.
        Used to indicate that some computation relies on ``src``.

        Parameters
        ----------
        src : :class:`.TensorSignal`
            Signal indicating the data being read
        """

        self.gather_bases.append(self.bases[src.key])

    def combine(self, sigs, label="Combine"):
        """Concatenates several TensorSignals along the first axis into a
        single TensorSignal.

        Parameters
        ----------
        sigs : list of :class:`.TensorSignal` or \
                       :class:`~nengo:nengo.builder.Signal`
            Signals to be combined
        label : str, optional
            Name for combined signal (to help with debugging)

        Returns
        -------
        :class:`.TensorSignal`
            New TensorSignal representing the concatenation of the data in
            ``sigs`` (an empty list is returned if ``sigs`` is empty)
        """

        if len(sigs) == 0:
            return []

        assert isinstance(sigs, (list, tuple))
        assert isinstance(sigs[0], (Signal, TensorSignal))

        # resolve plain Signals to their TensorSignal
        sigs = [self[s] if isinstance(s, Signal) else s for s in sigs]
        first = sigs[0]

        # a shared key implies a shared base, hence the same dtype and
        # minibatching
        key = first.key
        assert all(s.key == key for s in sigs)

        # everything except the first (concatenation) axis has to agree;
        # reshaping means this can differ even for signals on the same base
        trailing = first.shape[1:]
        shape = (np.sum([s.shape[0] for s in sigs]),) + trailing
        assert all(s.shape[1:] == shape[1:] for s in sigs)

        indices = np.concatenate([s.indices for s in sigs], axis=0)

        return self.get_tensor_signal(indices, key, first.dtype, shape,
                                      first.minibatched, label=label)

    def make_internal(self, name, shape, minibatched=True):
        # Creates a zero-initialized, non-trainable local variable with a
        # fresh key, registers it in ``internal_vars``, and returns the
        # TensorSignal covering all of its rows.
        sig = self.get_tensor_signal(
            np.arange(shape[0]), object(), self.dtype, shape, minibatched,
            label=name)

        scope_name = tf.get_default_graph().get_name_scope()
        with tf.variable_scope(scope_name, reuse=False):
            var = tf.get_local_variable(
                name, shape=sig.full_shape, dtype=sig.dtype, trainable=False,
                initializer=tf.zeros_initializer())

        self.internal_vars[sig.key] = var

        return sig

    def get_tensor_signal(self, indices, key, dtype, shape, minibatched,
                          signal=None, label="TensorSignal"):
        """
        Builds a ``TensorSignal`` wired up to this SignalDict.

        Prefer this over constructing a TensorSignal directly: it supplies
        the minibatch size and the custom :meth:`.constant` function, and
        can register the result in ``sig_map``.

        Parameters
        ----------
        indices : tuple or list or :class:`~numpy:numpy.ndarray` of int
            Indices along the first axis of the base array corresponding to
            the data for this signal
        key : object
            Key mapping to the base array that contains the data for this
            signal
        dtype : :class:`~numpy:numpy.dtype`
            dtype of the values represented by this signal
        shape : tuple of int
            View shape of this signal (may differ from shape of base array)
        minibatched : bool
            Whether or not this signal contains a minibatch dimension
        signal : :class:`~nengo:nengo.builder.Signal`, optional
            If not None, associate the new ``TensorSignal`` with the given
            ``Signal`` in the ``sig_map``
        label : str, optional
            Name for this signal, used to make debugging easier

        Returns
        -------
        :class:`.TensorSignal`
            A new ``TensorSignal`` with the given properties
        """

        batch = self.minibatch_size if minibatched else None
        tensor_sig = TensorSignal(
            indices, key, dtype, shape, batch, self.constant, label=label)

        if signal is None:
            return tensor_sig

        # sanity-check the view against the Signal before registering it
        assert len(indices) == (1 if len(signal.shape) == 0 else
                                signal.shape[0])
        assert signal.size == np.prod(shape)
        assert signal.minibatched == minibatched
        self[signal] = tensor_sig

        return tensor_sig

    def constant(self, value, dtype=None, cutoff=1 << 25):
        """
        Returns a constant Tensor containing the given value.

        Small values become a ``tf.constant`` op.  Values larger than
        ``cutoff`` bytes are stored in a ``tf.Variable`` initialized from a
        placeholder instead, to avoid storing large constant values in the
        TensorFlow GraphDef, which has a hard-coded limit of 2GB at the
        moment.

        Parameters
        ----------
        value : :class:`~numpy:numpy.ndarray`
            Array containing the value of the constant
        dtype : ``tf.DType``, optional
            The type for the constant (if ``None``, the dtype of ``value``
            will be used)
        cutoff : int, optional
            The size of constant (in bytes) for which we will switch from
            ``tf.constant`` to ``tf.Variable``

        Returns
        -------
        ``tf.Tensor``
            A tensor representing the given value
        """

        value = np.asarray(value)
        dtype = tf.as_dtype(value.dtype if dtype is None else dtype)

        if not value.nbytes > cutoff:
            return tf.constant(value, dtype=dtype)

        def make_ph(shape, dtype, **_):
            # the placeholder is remembered alongside the value it stands for
            ph = tf.placeholder(dtype, shape)
            self.constant_phs[ph] = value
            return ph

        with tf.variable_scope("constant_vars", reuse=False):
            # tensorflow doesn't support int32 variables on the gpu, only
            # int64 (for some reason). int64 would inflate the size, so the
            # variable is allowed to live on the CPU if necessary and the
            # identity below moves it to the GPU
            # TODO: double check if this is still true in 1.9.0
            with tf.device(None):
                const_var = tf.get_variable(
                    "constant_%d" % len(self.constant_phs),
                    initializer=make_ph, shape=value.shape, dtype=dtype,
                    collections=["constants"], trainable=False)

            return tf.identity(const_var)

    def op_constant(self, ops, op_sizes, attr, dtype, ndims=2):
        """
        Creates a tensor representing the constant parameters of an op group.

        Parameters
        ----------
        ops : list of object
            The operators for some merged group of ops
        op_sizes : list of int
            The number of constant elements in each op
        attr : str
            The attribute of the op that describes the constant parameter
        dtype : ``tf.DType``
            Numeric type of the parameter
        ndims : int
            Empty dimensions will be added to the end of the returned tensor
            for all ndims > 1 (in the case that it is not a scalar).

        Returns
        -------
        ``tf.Tensor``
            Tensor containing the values of ``attr`` for the given ops.  This
            will be a scalar if all the ops have the same parameter value, or
            an array giving the parameter value for each element in each op.
        """

        first_val = getattr(ops[0], attr)
        if np.allclose([getattr(op, attr) for op in ops], first_val):
            # every op shares (approximately) the same value
            return tf.constant(first_val, dtype=dtype)

        # one entry per element of each op, padded with singleton dimensions
        pad = [1] * (ndims - 1)
        per_element = [np.reshape(getattr(op, attr), pad)
                       for i, op in enumerate(ops)
                       for _ in range(op_sizes[i])]
        return self.constant(per_element, dtype=dtype)

    def __getitem__(self, sig):
        return self.sig_map[sig]

    def __setitem__(self, sig, tensor_sig):
        self.sig_map[sig] = tensor_sig

    def __len__(self):
        return len(self.sig_map)

    def __iter__(self):
        return iter(self.sig_map)

    def __contains__(self, sig):
        return sig in self.sig_map