Source code for nengo_dl.process_builders

"""
Build classes for Nengo process operators.
"""

from collections import OrderedDict
import logging

from nengo.builder.processes import SimProcess
from nengo.exceptions import SimulationError
from nengo.synapses import Lowpass, LinearFilter
from nengo.utils.filter_design import cont2discrete
import numpy as np
import tensorflow as tf

from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder


logger = logging.getLogger(__name__)


class GenericProcessBuilder(OpBuilder):
    """
    Builds all process types for which there is no custom TensorFlow
    implementation.

    Notes
    -----
    These will be executed as native Python functions, requiring execution to
    move in and out of TensorFlow. This can significantly slow down the
    simulation, so any performance-critical processes should consider adding a
    custom TensorFlow implementation for their type instead.
    """

    def __init__(self, ops, signals, config):
        super().__init__(ops, signals, config)

        self.time_data = signals[ops[0].t].reshape(())
        self.input_data = (
            None if ops[0].input is None else signals.combine([op.input for op in ops])
        )
        self.output_data = signals.combine([op.output for op in ops])
        self.state_data = [
            signals.combine([list(op.state.values())[i] for op in ops])
            for i in range(len(ops[0].state))
        ]
        self.mode = "inc" if ops[0].mode == "inc" else "update"
        self.prev_result = []

        # build the step function for each process
        self.step_fs = [[None for _ in range(signals.minibatch_size)] for _ in ops]

        # `merged_func` calls the step function for each process and
        # combines the result
        @utils.align_func(
            [self.output_data.full_shape] + [s.full_shape for s in self.state_data],
            [self.output_data.dtype] + [s.dtype for s in self.state_data],
        )
        def merged_func(time, *input_state):  # pragma: no cover (runs in TF)
            if any(x is None for a in self.step_fs for x in a):
                raise SimulationError("build_post has not been called for %s" % self)

            if self.input_data is None:
                input = None
                state = input_state
            else:
                input = input_state[0]
                state = input_state[1:]

            # update state in-place (this will update the state values
            # inside step_fs)
            for i, s in enumerate(state):
                self.step_states[i][...] = s

            input_offset = 0
            func_output = []
            for i, op in enumerate(ops):
                if op.input is not None:
                    input_shape = op.input.shape[0]
                    func_input = input[:, input_offset : input_offset + input_shape]
                    input_offset += input_shape

                mini_out = []
                for j in range(signals.minibatch_size):
                    x = [] if op.input is None else [func_input[j]]
                    mini_out += [self.step_fs[i][j](*([time] + x))]
                func_output += [np.stack(mini_out, axis=0)]

            return [np.concatenate(func_output, axis=1)] + self.step_states

        self.merged_func = merged_func
        self.merged_func.__name__ = utils.sanitize_name(
            "_".join([type(op.process).__name__ for op in ops])
        )

    def build_step(self, signals):
        time = [signals.gather(self.time_data)]
        input = [] if self.input_data is None else [signals.gather(self.input_data)]
        state = [signals.gather(s) for s in self.state_data]

        # note: we need to make sure that the previous call to this function
        # has completed before the next starts, since we don't know that the
        # functions are thread safe
        # TODO: this isn't necessary in eager mode
        with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
            result = tf.numpy_function(
                self.merged_func,
                time + input + state,
                [self.output_data.dtype] + [s.dtype for s in self.state_data],
                name=self.merged_func.__name__,
            )
        output = result[0]
        state = result[1:]

        self.prev_result = [output]

        output.set_shape(self.output_data.full_shape)
        signals.scatter(self.output_data, output, mode=self.mode)
        for i, s in enumerate(state):
            s.set_shape(self.state_data[i].full_shape)
            signals.scatter(self.state_data[i], s, mode="update")

    def build_post(self, ops, signals, config):
        # generate state for each op
        step_states = [
            op.process.make_state(
                op.input.shape if op.input is not None else (0,),
                op.output.shape,
                signals.dt_val,
            )
            for op in ops
        ]

        # build all the states into combined array with shape
        # (n_states, n_ops, *state_d)
        combined_states = [[None for _ in ops] for _ in range(len(ops[0].state))]
        for i, op in enumerate(ops):
            # note: we iterate over op.state so that the order is always based on that
            # dict's order (which is what we used to set up self.state_data)
            for j, name in enumerate(op.state):
                combined_states[j][i] = step_states[i][name]

        # combine op states, giving shape
        # (n_states, n_ops * state_d[0], *state_d[1:])
        # (keeping track of the offset of where each op's state lies in the
        # combined array)
        offsets = [[s.shape[0] for s in state] for state in combined_states]
        offsets = np.cumsum(offsets, axis=-1)
        self.step_states = [np.concatenate(state, axis=0) for state in combined_states]

        # duplicate state for each minibatch, giving shape
        # (n_states, minibatch_size, n_ops * state_d[0], *state_d[1:])
        assert all(s.minibatched for op in ops for s in op.state.values())
        for i, state in enumerate(self.step_states):
            self.step_states[i] = np.tile(
                state[None, ...], (signals.minibatch_size,) + (1,) * state.ndim
            )

        # build the step functions
        for i, op in enumerate(ops):
            for j in range(signals.minibatch_size):
                # pass each make_step function a view into the combined state
                state = {}
                for k, name in enumerate(op.state):
                    start = 0 if i == 0 else offsets[k][i - 1]
                    stop = offsets[k][i]

                    state[name] = self.step_states[k][j, start:stop]

                    assert np.allclose(state[name], step_states[i][name])

                self.step_fs[i][j] = op.process.make_step(
                    op.input.shape if op.input is not None else (0,),
                    op.output.shape,
                    signals.dt_val,
                    op.process.get_rng(config.rng),
                    state,
                )

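# Illustrative sketch (not part of the original module): any process type
# without an entry in ``SimProcessBuilder.TF_PROCESS_IMPL`` falls back to
# ``GenericProcessBuilder`` above, which wraps the process's own
# ``make_state``/``make_step`` functions in a ``tf.numpy_function`` call.
# For example, a hypothetical user-defined process such as
#
#     import nengo
#
#     class ScaleProcess(nengo.Process):
#         """Toy process that simply doubles its input."""
#
#         def make_step(self, shape_in, shape_out, dt, rng, state):
#             def step(t, x):
#                 return 2.0 * x
#
#             return step
#
# would be simulated through this generic (and comparatively slow) Python
# path, whereas `~nengo.Lowpass` and `~nengo.LinearFilter` synapses use the
# dedicated TensorFlow builders below.
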
class LowpassBuilder(OpBuilder):
    """Build a group of `~nengo.Lowpass` synapse operators."""

    def __init__(self, ops, signals, config):
        super().__init__(ops, signals, config)

        # the main difference between this and the general linearfilter
        # OneX implementation is that this version allows us to merge
        # synapses with different input dimensionality (by duplicating
        # the synapse parameters for every input, rather than using
        # broadcasting)

        self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        nums = []
        dens = []
        for op in ops:
            if op.process.tau <= 0.03 * signals.dt_val:
                num = 1
                den = 0
            else:
                num, den, _ = cont2discrete(
                    (op.process.num, op.process.den), signals.dt_val, method="zoh"
                )
                num = num.flatten()

                num = num[1:] if num[0] == 0 else num
                assert len(num) == 1
                num = num[0]

                assert len(den) == 2
                den = den[1]

            nums += [num] * op.input.shape[0]
            dens += [den] * op.input.shape[0]

        if self.input_data.minibatched:
            # add batch dimension for broadcasting
            nums = np.expand_dims(nums, 0)
            dens = np.expand_dims(dens, 0)

        # apply the negative here
        dens = -np.asarray(dens)

        self.nums = tf.constant(nums, dtype=self.output_data.dtype)
        self.dens = tf.constant(dens, dtype=self.output_data.dtype)

        # create a variable to represent the internal state of the filter
        # self.state_sig = signals.make_internal(
        #     "state", self.output_data.shape)

    def build_step(self, signals):
        input = signals.gather(self.input_data)
        output = signals.gather(self.output_data)
        signals.scatter(self.output_data, self.dens * output + self.nums * input)

        # method using an internal state signal
        # note: this approach isn't used, for efficiency reasons (we can avoid
        # an extra scatter by reusing the output signal as the state signal)
        # input = signals.gather(self.input_data)
        # prev_state = signals.gather(self.state_sig)
        # new_state = self.dens * prev_state + self.nums * input
        # signals.scatter(self.state_sig, new_state)
        # signals.scatter(self.output_data, new_state)

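# Worked example (illustrative, not part of the original module): for a
# ``nengo.Lowpass(tau)`` synapse, the zero-order-hold discretization computed
# by ``cont2discrete`` above gives
#
#     a = np.exp(-dt / tau)
#     num = [1 - a]   # feedforward coefficient
#     den = [1, -a]   # feedback coefficients
#
# so after dropping the leading zero in ``num`` and negating ``den[1]``,
# ``build_step`` applies
#
#     y[t] = a * y[t - 1] + (1 - a) * x[t]
#
# For instance, with tau=0.005 and dt=0.001, a = exp(-0.2) ~= 0.819, so about
# 18% of the new input is mixed into the output on each timestep.
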
class LinearFilterBuilder(OpBuilder):
    """Build a group of `~nengo.LinearFilter` synapse operators."""

    def __init__(self, ops, signals, config):
        super().__init__(ops, signals, config)

        # note: linear filters are linear systems with n_inputs/n_outputs == 1.
        # we apply them to multidimensional inputs, but we do so by
        # broadcasting that SISO linear system (so it's effectively
        # d 1-dimensional linear systems). this means that we can make
        # some simplifying assumptions, namely that B has shape (state_d, 1),
        # C has shape (1, state_d), and D has shape (1, 1), and then we can
        # implement those operations as (broadcasted) multiplies rather than
        # full matrix multiplications.
        # this also means that the minibatch dimension is identical to the
        # signal dimension (i.e. n m-dimensional signals is the same as
        # 1 n*m-dimensional signal); in either case we're just doing that
        # B/C/D broadcasting along all the non-state dimensions. so in these
        # computations we collapse minibatch and signal dimensions into one.

        self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        if self.input_data.ndim != 1:
            raise NotImplementedError(
                "LinearFilter of non-vector signals is not implemented"
            )

        steps = [
            op.process.make_step(
                op.input.shape,
                op.output.shape,
                signals.dt_val,
                state=op.process.make_state(
                    op.input.shape, op.output.shape, signals.dt_val
                ),
                rng=None,
            )
            for op in ops
        ]
        self.step_type = type(steps[0])
        assert all(type(step) == self.step_type for step in steps)

        self.n_ops = len(ops)
        self.signal_d = ops[0].input.shape[0]
        self.state_d = steps[0].A.shape[0]

        if self.step_type == LinearFilter.NoX:
            self.A = None
            self.B = None
            self.C = None
            # combine D scalars for each op, and broadcast along minibatch and
            # signal dimensions
            self.D = tf.constant(
                np.concatenate([step.D[None, :, None] for step in steps], axis=1),
                dtype=signals.dtype,
            )
            assert self.D.shape == (1, self.n_ops, 1)
        elif self.step_type == LinearFilter.OneX:
            # combine A scalars for each op, and broadcast along batch/state
            self.A = tf.constant(
                np.concatenate([step.A for step in steps])[None, :],
                dtype=signals.dtype,
            )
            # combine B and C scalars for each op, and broadcast along batch/state
            self.B = tf.constant(
                np.concatenate([step.B * step.C for step in steps])[None, :],
                dtype=signals.dtype,
            )
            self.C = None
            self.D = None

            assert self.A.shape == (1, self.n_ops, 1)
            assert self.B.shape == (1, self.n_ops, 1)
        else:
            self.A = tf.constant(
                np.stack([step.A for step in steps], axis=0), dtype=signals.dtype
            )
            self.B = tf.constant(
                np.stack([step.B for step in steps], axis=0), dtype=signals.dtype
            )
            self.C = tf.constant(
                np.stack([step.C for step in steps], axis=0), dtype=signals.dtype
            )

            if self.step_type == LinearFilter.NoD:
                self.D = None
            else:
                self.D = tf.constant(
                    np.concatenate([step.D[:, None, None] for step in steps]),
                    dtype=signals.dtype,
                )
                assert self.D.shape == (self.n_ops, 1, 1)

            # create a variable to represent the internal state of the filter
            self.state_data = signals.combine([op.state["X"] for op in ops])

            assert self.A.shape == (self.n_ops, self.state_d, self.state_d)
            assert self.B.shape == (self.n_ops, self.state_d, 1)
            assert self.C.shape == (self.n_ops, 1, self.state_d)

    def build_step(self, signals):
        input = signals.gather(self.input_data)

        if self.step_type == LinearFilter.NoX:
            input = tf.reshape(input, (signals.minibatch_size, self.n_ops, -1))
            signals.scatter(self.output_data, self.D * input)
        elif self.step_type == LinearFilter.OneX:
            input = tf.reshape(input, (signals.minibatch_size, self.n_ops, -1))

            # note: we use the output signal in place of a separate state
            output = signals.gather(self.output_data)
            output = tf.reshape(output, (signals.minibatch_size, self.n_ops, -1))

            signals.scatter(self.output_data, self.A * output + self.B * input)
        else:
            # TODO: possible to rework things to not require all the
            # transposing/reshaping required for moving batch to end?
            def undo_batch(x):
                x = tf.reshape(
                    x, x.shape[:-1].as_list() + [-1, signals.minibatch_size]
                )
                x = tf.transpose(x, np.roll(np.arange(x.shape.ndims), 1))
                return x

            # separate by op and collapse batch/state dimensions
            assert input.shape.ndims == 2
            input = tf.transpose(input)
            input = tf.reshape(
                input, (self.n_ops, 1, self.signal_d * signals.minibatch_size)
            )

            state = signals.gather(self.state_data)
            assert state.shape.ndims == 3
            state = tf.transpose(state, perm=(1, 2, 0))
            state = tf.reshape(
                state,
                (self.n_ops, self.state_d, self.signal_d * signals.minibatch_size),
            )

            if self.step_type == LinearFilter.NoD:
                # for NoD, we update the state before computing the output
                new_state = tf.matmul(self.A, state) + self.B * input
                signals.scatter(self.state_data, undo_batch(new_state))

                output = tf.matmul(self.C, new_state)
                signals.scatter(self.output_data, undo_batch(output))
            else:
                # in the general case, we compute the output before updating
                # the state
                output = tf.matmul(self.C, state)
                if self.step_type == LinearFilter.General:
                    output += self.D * input
                signals.scatter(self.output_data, undo_batch(output))

                new_state = tf.matmul(self.A, state) + self.B * input
                signals.scatter(self.state_data, undo_batch(new_state))

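# Reference equations (illustrative, not part of the original module): each
# op above applies a discretized SISO state-space system
#
#     X[t + 1] = A @ X[t] + B * u[t]
#     y[t]     = C @ X[t] + D * u[t]
#
# and ``build_step`` specializes this to the step type returned by
# ``LinearFilter.make_step``:
#
#     NoX     -- no state, so y = D * u
#     OneX    -- a single state that doubles as the output, so
#                y[t] = A * y[t - 1] + (B * C) * u[t]
#     NoD     -- no passthrough; the state is updated first, then y = C @ X_new
#     General -- output computed from the current state plus the D * u
#                passthrough, then the state is updated
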
@Builder.register(SimProcess)
class SimProcessBuilder(OpBuilder):
    """
    Builds a group of `~nengo.builder.processes.SimProcess` operators.

    Calls the appropriate sub-build class for the different process types.

    Attributes
    ----------
    TF_PROCESS_IMPL : dict of {`~nengo.Process`: `.builder.OpBuilder`}
        Mapping from process types to custom build classes (processes without
        a custom builder will use the generic builder).
    """

    # we use OrderedDict because it is important that Lowpass come before
    # LinearFilter (since we'll be using isinstance to find the right builder,
    # and Lowpass is a subclass of LinearFilter)
    TF_PROCESS_IMPL = OrderedDict(
        [(Lowpass, LowpassBuilder), (LinearFilter, LinearFilterBuilder)]
    )

    def __init__(self, ops, signals, config):
        super().__init__(ops, signals, config)

        logger.debug("process %s", [op.process for op in ops])
        logger.debug("input %s", [op.input for op in ops])
        logger.debug("output %s", [op.output for op in ops])
        logger.debug("t %s", [op.t for op in ops])

        # if we have a custom tensorflow implementation for this process type,
        # then we build that. otherwise we'll execute the process step
        # function externally (using `tf.numpy_function`).
        for process_type, process_builder in self.TF_PROCESS_IMPL.items():
            if isinstance(ops[0].process, process_type):
                self.built_process = process_builder(ops, signals, config)
                break
        else:
            self.built_process = GenericProcessBuilder(ops, signals, config)

    def build_step(self, signals):
        self.built_process.build_step(signals)

    def build_post(self, ops, signals, config):
        if isinstance(self.built_process, GenericProcessBuilder):
            self.built_process.build_post(ops, signals, config)

    @staticmethod
    def mergeable(x, y):
        # we can merge ops if they have a custom implementation, or merge
        # generic processes, but can't mix the two

        custom_impl = tuple(SimProcessBuilder.TF_PROCESS_IMPL.keys())
        if isinstance(x.process, custom_impl):
            if type(x.process) == Lowpass or type(y.process) == Lowpass:
                # lowpass ops can only be merged with other lowpass ops, since
                # they have a custom implementation
                if type(x.process) != type(y.process):  # noqa: E721
                    return False
            elif isinstance(x.process, LinearFilter):
                # we can only merge linearfilters that have the same state
                # dimensionality (den), the same step type (num), and the same
                # input signal dimensionality
                if (
                    not isinstance(y.process, LinearFilter)
                    or len(y.process.num) != len(x.process.num)
                    or len(y.process.den) != len(x.process.den)
                    or x.input.shape[0] != y.input.shape[0]
                ):
                    return False
            else:
                raise NotImplementedError()
        elif isinstance(y.process, custom_impl):
            return False

        return True
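# Examples of the merge rules above (illustrative, not part of the original
# module), assuming the ops have matching input dimensionality and ignoring
# any additional graph-level merge constraints applied elsewhere:
#
#     Lowpass(0.005) + Lowpass(0.1)   -> mergeable (LowpassBuilder)
#     Alpha(0.005)   + Alpha(0.1)     -> mergeable (LinearFilterBuilder;
#                                        same num/den lengths)
#     Lowpass(0.005) + Alpha(0.005)   -> not mergeable (Lowpass only merges
#                                        with other Lowpass ops)
#     Lowpass(0.005) + custom Process -> not mergeable (custom vs. generic
#                                        implementation)
#     custom Process + custom Process -> mergeable (both fall back to
#                                        GenericProcessBuilder)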