"""
Build classes for basic Nengo operators.
"""
import logging
import warnings
from collections import defaultdict

import numpy as np
import tensorflow as tf
from nengo.builder.operator import (
    Copy,
    DotInc,
    ElementwiseInc,
    Reset,
    SimPyFunc,
    TimeUpdate,
)
from nengo.builder.transforms import SparseDotInc
from nengo.transforms import SparseMatrix
from tensorflow.python.ops import gen_sparse_ops

from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder

logger = logging.getLogger(__name__)


class ResetInc(Reset):
"""
A version of `~nengo.builder.operator.Reset` that increments the target value
rather than overwriting.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.incs, self.sets = self.sets, self.incs
@property
def dst(self):
"""dst is stored in ``incs`` rather than ``sets``."""
return self.incs[0]


class ElementwiseSet(ElementwiseInc):
"""
A version of `~nengo.builder.operator.ElementwiseInc` that overwrites the target
rather than incrementing.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.incs, self.sets = self.sets, self.incs
@property
def Y(self):
"""Y is stored in ``sets`` rather than ``incs``."""
return self.sets[0]


class DotSet(DotInc):
"""
A version of `~nengo.builder.operator.DotInc` that overwrites the target rather
than incrementing.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.incs, self.sets = self.sets, self.incs
@property
def Y(self):
"""Y is stored in ``sets`` rather than ``incs``."""
return self.sets[0]


class SparseDotSet(SparseDotInc):
"""
    A version of `~nengo.builder.transforms.SparseDotInc` that overwrites the target
rather than incrementing.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.incs, self.sets = self.sets, self.incs
@property
def Y(self):
"""Y is stored in ``sets`` rather than ``incs``."""
return self.sets[0]
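

# Note: the *Inc/*Set variants above only swap an operator's ``incs`` and
# ``sets`` lists, which controls whether the build writes to or accumulates
# into the target signal. A rough, hypothetical sketch of the intended
# semantics (illustrative only, not operator code from this module):
#
#     ElementwiseInc(A, X, Y)  # Y lives in op.incs  ->  Y += A * X
#     ElementwiseSet(A, X, Y)  # Y lives in op.sets  ->  Y  = A * X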


@Builder.register(Reset)
@Builder.register(ResetInc)
class ResetBuilder(OpBuilder):
"""
Build a group of `~nengo.builder.operator.Reset` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
logger.debug("val %s", [op.value for op in self.ops])
logger.debug("dst %s", [op.dst for op in self.ops])
self.mode = "inc" if type(self.ops[0]) == ResetInc else "update"
dtype = np.asarray(self.ops[0].value).dtype
if np.issubdtype(dtype, np.floating):
dtype = signals.dtype.as_numpy_dtype
# Reset signals might be spread across multiple bases, so group them
# by the ones that do share a base
scatters = defaultdict(list)
for op in self.ops:
scatters[signals[op.dst].key].append(op)
self.scatters = []
for group in scatters.values():
value = tf.concat(
[
tf.broadcast_to(
tf.cast(x.value, dtype),
(signals.minibatch_size,) + x.dst.shape,
)
for x in group
],
axis=1,
)
self.scatters.append((signals.combine([x.dst for x in group]), value))
logger.debug("scatters")
logger.debug("\n".join([str(x) for x in self.scatters]))
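        # Grouping sketch (hypothetical layout): if op0.dst and op1.dst are
        # views into the same base signal while op2.dst lives in a different
        # base, ``self.scatters`` ends up with two entries: one combined
        # scatter for [op0, op1] (their reset values concatenated along
        # axis 1) and one for op2 alone.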

    def build_step(self, signals):
for data, val in self.scatters:
signals.scatter(data, val, mode=self.mode)

    @staticmethod
def mergeable(x, y):
return True


@Builder.register(Copy)
class CopyBuilder(OpBuilder):
"""
Build a group of `~nengo.builder.operator.Copy` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
logger.debug("src %s", [op.src for op in self.ops])
logger.debug(
"src_slice %s", [getattr(op, "src_slice", None) for op in self.ops]
)
logger.debug("dst %s", [op.dst for op in self.ops])
logger.debug(
"dst_slice %s", [getattr(op, "dst_slice", None) for op in self.ops]
)
self.src_data = signals.combine(
[signals[op.src][op.src_slice] for op in self.ops]
)
self.dst_data = signals.combine(
[signals[op.dst][op.dst_slice] for op in self.ops]
)
self.mode = "inc" if self.ops[0].inc else "update"

    def build_step(self, signals):
src = signals.gather(self.src_data)
if not self.src_data.minibatched and self.dst_data.minibatched:
src = tf.broadcast_to(src, self.dst_data.full_shape)
signals.scatter(self.dst_data, src, mode=self.mode)
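        # Broadcast sketch (hypothetical shapes): a non-minibatched src of
        # shape (4,) being copied into a minibatched dst is first broadcast
        # to (minibatch_size, 4) so that every example in the batch receives
        # the same values.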

    @staticmethod
def mergeable(x, y):
return True


@Builder.register(ElementwiseInc)
@Builder.register(ElementwiseSet)
class ElementwiseIncBuilder(OpBuilder):
"""
Build a group of `~nengo.builder.operator.ElementwiseInc` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
logger.debug("dst %s", [op.Y for op in self.ops])
logger.debug("A %s", [op.A for op in self.ops])
logger.debug("X %s", [op.X for op in self.ops])
self.mode = "inc" if type(self.ops[0]) == ElementwiseInc else "update"
self.Y_data = signals.combine([op.Y for op in self.ops])
# group all the A's and X's
self.A_data = signals.combine([op.A for op in self.ops])
self.X_data = signals.combine([op.X for op in self.ops])
# separate data from each op along the first dimension
# (we only need to do this if they don't have the same length already)
if self.A_data.shape[0] != self.X_data.shape[0]:
self.A_data = self.A_data.reshape(
(len(self.ops), -1) + self.A_data.shape[1:]
)
self.X_data = self.X_data.reshape(
(len(self.ops), -1) + self.X_data.shape[1:]
)
# add empty trailing dimensions for elementwise broadcasting
while self.A_data.ndim < self.X_data.ndim:
self.A_data = self.A_data.reshape(self.A_data.shape + (1,))
# add broadcast dimension for minibatch, if needed
if not self.A_data.minibatched and self.X_data.minibatched:
self.A_data = self.A_data.reshape((1,) + self.A_data.shape)
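        # Shape sketch (hypothetical sizes): two ops with A shapes (3,) and
        # (3,) and X shapes (3, 2) and (3, 2) combine to A_data (6,) and
        # X_data (6, 2); the first dimensions already match, so no per-op
        # split is needed, A gains a trailing broadcast dim -> (6, 1), and
        # if only X is minibatched A becomes (1, 6, 1) so that A * X
        # broadcasts elementwise.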

    def build_step(self, signals):
A = signals.gather(self.A_data)
X = signals.gather(self.X_data)
result = tf.multiply(A, X)
signals.scatter(self.Y_data, result, mode=self.mode)

    @staticmethod
def mergeable(x, y):
# for these operations we enforce that the first dimensions
# match (we know all the other dimensions match due to the generic
# checks).
        # this allows us to stack all the arguments into contiguous array
# blocks, allowing for more efficient multiplication (mainly
# because it allows us to take advantage of broadcasting)
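        # for example (hypothetical shapes): two ElementwiseInc ops whose A
        # signals both have shape (3,) can be merged (3 == 3), while shapes
        # (3,) and (4,) cannot, since those rows could not be stacked into a
        # single contiguous block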
for s0, s1 in zip(x.all_signals, y.all_signals):
shape0 = s0.shape[0] if s0.shape != () else 1
shape1 = s1.shape[0] if s1.shape != () else 1
if shape0 != shape1:
return False
return True


def sparse_matmul(A_indices, A_data, A_shape, X, transpose_x=False):
"""
    Matrix multiplication between sparse matrix A and dense matrix X.

    Parameters
    ----------
    A_indices : ``tf.Tensor``
        (N, 2) array of [row, col] indices of the non-zero entries
    A_data : ``tf.Tensor``
        (N,) array of data in the non-zero entries specified in ``A_indices``
    A_shape : tuple of int
        Shape of the full A matrix
    X : ``tf.Tensor``
        Dense matrix being multiplied by A
    transpose_x : bool
        Transpose X before multiplying

    Returns
    -------
    dot : ``tf.Tensor``
        Result of the matrix multiplication between A and X
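
    Examples
    --------
    A rough, hypothetical sketch of a call to this helper (the values below
    are illustrative only); ``A`` is the 2x2 diagonal matrix
    ``[[2, 0], [0, 3]]`` stored in COO form::

        indices = tf.constant([[0, 0], [1, 1]], dtype=tf.int64)
        data = tf.constant([2.0, 3.0])
        shape = tf.constant([2, 2], dtype=tf.int64)
        x = tf.constant([[1.0], [1.0]])
        y = sparse_matmul(indices, data, shape, x)  # [[2.0], [3.0]]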
"""
must_downcast = A_data.dtype.base_dtype != tf.float32 and (
"gpu" in A_data.device.lower()
or (A_data.device == "" and utils.tf_gpu_installed)
)
if must_downcast:
assert A_data.dtype.base_dtype == X.dtype.base_dtype
warnings.warn(
"Downcasting data to float32 in sparse_matmul, since "
"only float32 is supported on the GPU."
)
A = tf.cast(A_data, tf.float32)
X = tf.cast(X, tf.float32)
else:
A = A_data
dot = gen_sparse_ops.sparse_tensor_dense_mat_mul(
A_indices, A, A_shape, X, adjoint_b=transpose_x
)
if must_downcast:
dot = tf.cast(dot, A_data.dtype.base_dtype)
return dot


@Builder.register(DotInc)
@Builder.register(DotSet)
class DotIncBuilder(OpBuilder):
"""
Build a group of `~nengo.builder.operator.DotInc` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
logger.debug("dst %s", [op.Y for op in self.ops])
logger.debug("A %s", [op.A for op in self.ops])
logger.debug("X %s", [op.X for op in self.ops])
self.mode = "inc" if type(self.ops[0]) == DotInc else "update"
self.Y_data = signals.combine([op.Y for op in self.ops])
# group all the A's and X's
A_data = signals.combine([op.A for op in self.ops])
X_data = signals.combine([op.X for op in self.ops])
# separate data from each op along the first dimension
self.A_data = A_data.reshape((len(self.ops), -1, A_data.shape[1]))
self.X_data = X_data.reshape((len(self.ops), -1))
if self.A_data.minibatched:
# change X to matrix
self.X_data = self.X_data.reshape(self.X_data.shape + (1,))
else:
# precompute transposition permutation
self.perm = tf.constant((1, 2, 0))
self.perm_inv = tf.constant((2, 0, 1))

    def build_step(self, signals):
A = signals.gather(self.A_data)
X = signals.gather(self.X_data)
if self.A_data.minibatched and self.X_data.minibatched:
# (batch, n_ops, a0, a1) x (batch, n_ops, a1, 1)
dot = tf.matmul(A, X)
elif not self.A_data.minibatched and self.X_data.minibatched:
# (n_ops, a0, a1) x (batch, n_ops, a1)
# -> (n_ops, a0, a1) x (n_ops, a1, batch)
dot = tf.matmul(A, tf.transpose(X, perm=self.perm))
# transpose back to (batch, n_ops, a0)
dot = tf.transpose(dot, perm=self.perm_inv)
# for some reason the transposing causes TensorFlow to lose track of
# the shape (only when the `perm` constants are outside the loop)
dot.set_shape((signals.minibatch_size,) + self.A_data.shape[:2])
else:
raise NotImplementedError
signals.scatter(self.Y_data, dot, mode=self.mode)
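        # Shape sketch for the branch above where only X is minibatched
        # (hypothetical sizes): A (n_ops=2, a0=3, a1=4), X (batch=16, 2, 4);
        # transpose(X, (1, 2, 0)) -> (2, 4, 16); matmul -> (2, 3, 16);
        # transpose(dot, (2, 0, 1)) -> (16, 2, 3) == (batch, n_ops, a0).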

    @staticmethod
def mergeable(x, y):
# the first dimensions need to match up (to allow us to separate them by op)
for s0, s1 in zip(x.all_signals, y.all_signals):
shape0 = s0.shape[0] if s0.shape != () else 1
shape1 = s1.shape[0] if s1.shape != () else 1
if shape0 != shape1:
return False
return True


@Builder.register(SimPyFunc)
class SimPyFuncBuilder(OpBuilder):
"""
Build a group of `~nengo.builder.operator.SimPyFunc` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
logger.debug("t %s", [op.t for op in self.ops])
logger.debug("x %s", [op.x for op in self.ops])
logger.debug("fn %s", [op.fn for op in self.ops])
self.time_data = (
None if self.ops[0].t is None else signals[self.ops[0].t].reshape(())
)
self.input_data = signals.combine([op.x for op in self.ops])
if self.ops[0].output is not None:
self.output_data = signals.combine([op.output for op in self.ops])
self.output_dtype = self.output_data.dtype
else:
self.output_data = None
self.output_dtype = signals.dtype
def merged_func(time, inputs): # pragma: no cover (runs in TF)
outputs = []
offset = 0
for op in self.ops:
if op.output is None:
func = op.fn
else:
func = utils.align_func(self.output_dtype)(op.fn)
func_input = inputs[:, offset : offset + op.x.shape[0]]
offset += op.x.shape[0]
mini_out = []
for j in range(signals.minibatch_size):
if op.t is None:
func_out = func(func_input[j])
else:
func_out = func(time, func_input[j])
func_out = np.atleast_1d(func_out)
if op.output is None:
# just return time as a noop (since we need to
# return something)
func_out = [time]
mini_out += [func_out]
outputs += [np.stack(mini_out, axis=0)]
return np.concatenate(outputs, axis=1)
self.merged_func = merged_func
self.merged_func.__name__ = "_".join(
[utils.function_name(op.fn) for op in self.ops]
)
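        # Input-slicing sketch (hypothetical shapes): with two ops whose x
        # signals have shapes (2,) and (3,), the concatenated ``inputs``
        # passed to ``merged_func`` has shape (minibatch_size, 5); the first
        # op reads inputs[:, 0:2] and the second reads inputs[:, 2:5].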
self.output_shape = (signals.minibatch_size,)
self.output_shape += (
(len(self.ops),) if self.output_data is None else self.output_data.shape
)

    def build_step(self, signals):
time = [] if self.time_data is None else signals.gather(self.time_data)
inputs = [] if self.input_data is None else signals.gather(self.input_data)
node_outputs = tf.numpy_function(
self.merged_func,
[time, inputs],
self.output_dtype,
name=self.merged_func.__name__,
)
node_outputs.set_shape(self.output_shape)
if self.output_data is not None:
signals.scatter(self.output_data, node_outputs)
# note: we only need to run the node for side effects, not the
# assignment operator. if the result of the assignment is actually
# used anywhere, then it will be run as part of the normal graph.
return node_outputs

    @staticmethod
def mergeable(x, y):
# for these we need to make a special check that the functions
# all do/do not get time as input, otherwise we could end
# up confusing a node that only gets a scalar float input with
# a node that only gets time as input
return x.t == y.t


@Builder.register(SparseDotInc)
@Builder.register(SparseDotSet)
class SparseDotIncBuilder(OpBuilder):
"""
    Build a group of `~nengo.builder.transforms.SparseDotInc` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
self.mode = "inc" if type(self.ops[0]) == SparseDotInc else "update"
self.Y_data = signals.combine([op.Y for op in self.ops])
# group all the A's and X's
self.A_data = signals.combine([op.A for op in self.ops])
self.X_data = signals.combine([op.X for op in self.ops])
# the only way A would be minibatched is if it is targeted by an
# online learning rule, which isn't supported for sparse transforms
assert not self.A_data.minibatched
assert self.X_data.minibatched and self.Y_data.minibatched
# arrange the sparse matrices into a (sparse) block diagonal matrix
# by adding an offset to each sparse matrix's indices
sparse_indices = []
corner = np.zeros(2, dtype=np.int64)
for op in self.ops:
if isinstance(op.A.initial_value, SparseMatrix):
idxs = np.array(op.A.initial_value.indices)
else:
initial_value = op.A.initial_value.tocoo()
idxs = np.stack((initial_value.row, initial_value.col), axis=1)
block_shape = (op.A.shape[0], op.A.shape[1])
idxs += corner
corner += block_shape
sparse_indices += [idxs]
sparse_indices = np.concatenate(sparse_indices, axis=0)
self.sparse_indices = tf.constant(
sparse_indices,
dtype=(
tf.int32
if np.all(sparse_indices < np.iinfo(np.int32).max)
else tf.int64
),
)
self.A_shape = tf.constant(corner, dtype=tf.int64)
self.perm = tf.constant((1, 0))
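        # Block-diagonal sketch (hypothetical sizes): two ops with A shapes
        # (2, 3) and (4, 5) give corner offsets [0, 0] -> [2, 3] -> [6, 8],
        # so the second op's indices are shifted by (2, 3) and the combined
        # sparse matrix has shape (6, 8), multiplying a concatenated X of
        # length 8 per minibatch example.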

    def build_step(self, signals):
A = signals.gather(self.A_data)
X = signals.gather(self.X_data)
# (sum(a0s), sum(a1s)) x (batch, sum(a1s))
# -> (sum(a0s), sum(a1s)) x (sum(a1s), batch)
dot = sparse_matmul(self.sparse_indices, A, self.A_shape, X, transpose_x=True)
# transpose result back to (batch, sum(a0s))
dot = tf.transpose(dot, perm=self.perm)
dot.set_shape((signals.minibatch_size,) + self.Y_data.shape)
signals.scatter(self.Y_data, dot, mode=self.mode)

    @staticmethod
def mergeable(x, y):
return True


@Builder.register(TimeUpdate)
class TimeUpdateBuilder(OpBuilder):
"""
Build a group of `~nengo.builder.operator.TimeUpdate` operators.
"""

    def build_pre(self, signals, config):
super().build_pre(signals, config)
assert len(self.ops) == 1
op = self.ops[0]
self.step_data = signals[op.step]
self.time_data = signals[op.time]
self.one = tf.constant(1, dtype=tf.int32)

    def build_step(self, signals):
step = signals.gather(self.step_data)
step += self.one
signals.scatter(self.step_data, step)
signals.scatter(self.time_data, tf.cast(step, signals.dtype) * signals.dt)
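        # e.g. (hypothetical dt): with dt=0.001, after the third call to
        # build_step the step signal holds 3 and the time signal holds
        # 3 * 0.001 = 0.003.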

    @staticmethod
def mergeable(x, y):
# there should only ever be one TimeUpdate so this should never be called
raise NotImplementedError