from distutils.version import LooseVersion
import logging
import warnings
from nengo.builder.processes import SimProcess
from nengo.exceptions import SimulationError
from nengo.synapses import Lowpass, LinearFilter
from nengo.utils.filter_design import (cont2discrete, tf2ss, ss2tf,
BadCoefficients)
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_sparse_ops
from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder
# module-level logger; SimProcessBuilder uses it for debug output about the
# operators being built
logger = logging.getLogger(__name__)
@Builder.register(SimProcess)
class SimProcessBuilder(OpBuilder):
    """Builds a group of :class:`~nengo:nengo.builder.processes.SimProcess`
    operators.

    Calls the appropriate sub-build class for the different process types.

    Attributes
    ----------
    TF_PROCESS_IMPL : tuple of :class:`~nengo:nengo.Process`
        The process types that have a custom implementation
    """

    TF_PROCESS_IMPL = (Lowpass, LinearFilter)

    def __init__(self, ops, signals):
        """Select and construct the sub-builder for this group of ops.

        Parameters
        ----------
        ops : list of :class:`~nengo:nengo.builder.processes.SimProcess`
            The operators in this group; the process type of ``ops[0]``
            decides which sub-builder is used.
        signals
            Signal container passed through to the sub-builder.

        Raises
        ------
        NotImplementedError
            If the process type is listed in ``TF_PROCESS_IMPL`` but no
            matching custom builder is selected below.
        """
        super(SimProcessBuilder, self).__init__(ops, signals)

        # only build the (potentially long) debug lists when they will
        # actually be emitted
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("process %s", [op.process for op in ops])
            logger.debug("input %s", [op.input for op in ops])
            logger.debug("output %s", [op.output for op in ops])
            logger.debug("t %s", [op.t for op in ops])

        process = ops[0].process

        # if we have a custom tensorflow implementation for this process type,
        # then we build that. otherwise we'll execute the process step
        # function externally (using `tf.py_func`), so we just need to set up
        # the inputs/outputs for that.
        if isinstance(process, self.TF_PROCESS_IMPL):
            # note: we do this two-step check (even though it's redundant) to
            # make sure that TF_PROCESS_IMPL is kept up to date.
            # Lowpass is matched by exact type, so subclasses of Lowpass are
            # not routed to the Lowpass-specific builder.
            if type(process) is Lowpass:
                self.built_process = LowpassBuilder(ops, signals)
            elif isinstance(process, LinearFilter):
                self.built_process = LinearFilterBuilder(ops, signals)
            else:
                # fail loudly here rather than leaving built_process unset,
                # which would only surface later as an AttributeError
                raise NotImplementedError(
                    "%s is listed in TF_PROCESS_IMPL but has no custom "
                    "builder" % type(process).__name__)
        else:
            self.built_process = GenericProcessBuilder(ops, signals)

    def build_step(self, signals):
        """Delegate step construction to the selected sub-builder."""
        self.built_process.build_step(signals)

    def build_post(self, ops, signals, sess, rng):
        """Run post-build setup; only the generic (py_func) builder has any."""
        if isinstance(self.built_process, GenericProcessBuilder):
            self.built_process.build_post(ops, signals, sess, rng)
class GenericProcessBuilder(OpBuilder):
    """Builds all process types for which there is no custom TensorFlow
    implementation.

    The step functions of every process in the group are wrapped into one
    Python function, which is executed through ``tf.py_func`` on each
    timestep.

    Notes
    -----
    These will be executed as native Python functions, requiring execution to
    move in and out of TensorFlow. This can significantly slow down the
    simulation, so any performance-critical processes should consider
    adding a custom TensorFlow implementation for their type instead.
    """

    def __init__(self, ops, signals):
        super(GenericProcessBuilder, self).__init__(ops, signals)

        # processes without an input get no input signal at all
        if ops[0].input is None:
            self.input_data = None
        else:
            self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        self.output_shape = self.output_data.shape + (signals.minibatch_size,)
        self.mode = "inc" if ops[0].mode == "inc" else "update"

        # result of the most recent py_func call, used to chain the calls
        self.prev_result = []

        # self.step_fs[i][j] is the step function of op i for minibatch item
        # j; the None placeholders are replaced in build_post
        self.step_fs = [[None] * signals.minibatch_size for _ in ops]

        @utils.align_func(self.output_shape, self.output_data.dtype)
        def merged_func(time, input_vals):  # pragma: no cover
            """Call every step function and stack the results into one
            array (ops along axis 0, minibatch along the last axis)."""
            if any(f is None for row in self.step_fs for f in row):
                raise SimulationError(
                    "build_post has not been called for %s" % self)

            start = 0
            per_op = []
            for i, op in enumerate(ops):
                # slice out this op's rows of the combined input
                op_input = None
                if op.input is not None:
                    stop = start + op.input.shape[0]
                    op_input = input_vals[start:stop]
                    start = stop

                batch = []
                for j in range(signals.minibatch_size):
                    step_f = self.step_fs[i][j]
                    if op_input is None:
                        batch.append(step_f(time))
                    else:
                        batch.append(step_f(time, op_input[..., j]))
                per_op.append(np.stack(batch, axis=-1))

            return np.concatenate(per_op, axis=0)

        self.merged_func = merged_func
        self.merged_func.__name__ = utils.sanitize_name(
            "_".join(type(op.process).__name__ for op in ops))

    def build_step(self, signals):
        if self.input_data is None:
            input_vals = []
        else:
            input_vals = signals.gather(self.input_data)

        # the wrapped Python step functions are not known to be thread safe,
        # so each call is made to wait for the previous one to complete
        with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
            result = tf.py_func(
                self.merged_func, [signals.time, input_vals],
                self.output_data.dtype, name=self.merged_func.__name__)
        result.set_shape(self.output_shape)
        self.prev_result = [result]

        signals.scatter(self.output_data, result, mode=self.mode)

    def build_post(self, ops, signals, sess, rng):
        # create one step function per (op, minibatch item); get_rng is
        # called separately for each of them, ops outer / minibatch inner
        for i, op in enumerate(ops):
            op_fs = self.step_fs[i]
            shape_in = (0,) if op.input is None else op.input.shape
            for j in range(signals.minibatch_size):
                op_fs[j] = op.process.make_step(
                    shape_in, op.output.shape, signals.dt_val,
                    op.process.get_rng(rng))
class LowpassBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.Lowpass` synapse operators.

    Each process is discretized (zero-order hold) into one numerator
    coefficient and one denominator coefficient, broadcast over that
    operator's input rows. Every step then evaluates
    ``output = dens * output + nums * input`` for the whole group at once.

    The output signal itself is reused as the filter state, which avoids a
    separate state variable and an extra scatter per step.
    """

    def __init__(self, ops, signals):
        super(LowpassBuilder, self).__init__(ops, signals)

        self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        dt = signals.dt_val

        def coefficients(process):
            # -> (numerator coefficient, second denominator coefficient)
            if process.tau <= 0.03 * dt:
                # negligible time constant: output is just the input
                return 1, 0

            num, den, _ = cont2discrete((process.num, process.den), dt,
                                        method="zoh")
            num = num.flatten()
            if num[0] == 0:
                num = num[1:]
            assert len(num) == 1
            assert len(den) == 2
            return num[0], den[1]

        nums = []
        dens = []
        for op in ops:
            num_coef, den_coef = coefficients(op.process)
            n_rows = op.input.shape[0]
            nums.extend([num_coef] * n_rows)
            dens.extend([den_coef] * n_rows)

        n_dims = len(self.input_data.full_shape)

        def pad_dims(arr):
            # append trailing singleton axes so arr broadcasts against the
            # full signal shape
            while arr.ndim < n_dims:
                arr = np.expand_dims(arr, -1)
            return arr

        nums = pad_dims(np.asarray(nums))
        # the denominator coefficient is negated once here, so build_step
        # can simply multiply and add
        dens = pad_dims(-np.asarray(dens))

        self.nums = signals.constant(nums, dtype=self.output_data.dtype)
        self.dens = signals.constant(dens, dtype=self.output_data.dtype)

    def build_step(self, signals):
        x = signals.gather(self.input_data)
        y_prev = signals.gather(self.output_data)
        # the previous output serves as the filter state
        signals.scatter(self.output_data,
                        self.dens * y_prev + self.nums * x)
class LinearFilterBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.LinearFilter` synapse
    operators.

    Each filter is converted to a discrete state-space system in controllable
    form (A, C, D matrices). The A matrices of all operators in the group are
    packed into one block-diagonal sparse matrix, so the state of the whole
    group is advanced with a single sparse-dense matrix multiplication per
    step.

    NOTE(review): the state offsets, the output reshape, and
    ``np.concatenate(As, axis=0)`` rely on ``As[0]`` / ``ops[0]``, so this
    appears to assume every op in the group has the same filter order and
    the same input dimensionality -- presumably guaranteed by how ops are
    merged into groups; confirm.
    """

    def __init__(self, ops, signals):
        """Precompute the state-space matrices for ``ops`` and create the
        internal filter state signal.

        Parameters
        ----------
        ops : list of :class:`~nengo:nengo.builder.processes.SimProcess`
            The operators in this group.
        signals
            Provides ``combine``, ``constant``, ``make_internal``, ``dt_val``,
            ``dtype`` and ``minibatch_size``, all used below.
        """
        super(LinearFilterBuilder, self).__init__(ops, signals)

        self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        self.n_ops = len(ops)
        # input dimensionality, taken from the first op only
        self.signal_d = ops[0].input.shape[0]
        As = []
        Cs = []
        Ds = []
        # compute the A/C/D matrices for each operator
        for op in ops:
            A, B, C, D = tf2ss(op.process.num, op.process.den)

            if op.process.analog:
                # convert to discrete system
                A, B, C, D, _ = cont2discrete((A, B, C, D), signals.dt_val,
                                              method="zoh")

            # convert to controllable form
            num, den = ss2tf(A, B, C, D)

            if op.process.analog:
                # add shift (appends a trailing zero coefficient to num);
                # presumably compensates for the one-step delay of the ZOH
                # discretization -- confirm
                num = np.concatenate((num, [[0]]), axis=1)

            with warnings.catch_warnings():
                # ignore the warning about B, since we aren't using it anyway
                warnings.simplefilter("ignore", BadCoefficients)
                A, _, C, D = tf2ss(num, den)

            As.append(A)
            # single-output system: keep the first (only) row of C and the
            # scalar D
            Cs.append(C[0])
            Ds.append(D.item())

        # total number of state variables across all ops in the group
        self.state_d = sum(x.shape[0] for x in Cs)

        # build a sparse matrix containing the A matrices as blocks
        # along the diagonal
        sparse_indices = []
        # top-left (row, col) of the next block; after the loop this equals
        # the shape of the full block-diagonal matrix
        corner = np.zeros(2, dtype=np.int64)
        for A in As:
            # (row, col) pairs for every entry of A, in row-major order
            idxs = np.reshape(np.dstack(np.meshgrid(
                np.arange(A.shape[0]), np.arange(A.shape[1]),
                indexing="ij")), (-1, 2))
            idxs += corner
            corner += A.shape
            sparse_indices += [idxs]
        sparse_indices = np.concatenate(sparse_indices, axis=0)
        # values, in the same row-major order as the indices above
        self.A = signals.constant(np.concatenate(As, axis=0).flatten(),
                                  dtype=signals.dtype)
        # use 32-bit indices whenever they fit
        self.A_indices = signals.constant(sparse_indices, dtype=(
            tf.int32 if np.all(sparse_indices < np.iinfo(np.int32).max)
            else tf.int64))
        self.A_shape = tf.constant(corner, dtype=tf.int64)

        if np.allclose(Cs, 0):
            # C is (numerically) zero for every op; build_step then skips the
            # state contribution to the output
            self.C = None
        else:
            # add empty dimension for broadcasting
            self.C = signals.constant(np.concatenate(Cs)[:, None],
                                      dtype=signals.dtype)

        if np.allclose(Ds, 0):
            # no direct input-to-output term for any op
            self.D = None
        else:
            # add empty dimension for broadcasting
            self.D = signals.constant(np.asarray(Ds)[:, None],
                                      dtype=signals.dtype)

        # row index of the first state variable in each op's block; the
        # input is added at these rows when the state is updated
        self.offsets = tf.expand_dims(
            tf.range(0, len(ops) * As[0].shape[0], As[0].shape[0]),
            axis=1)

        # create a variable to represent the internal state of the filter
        # (one row per state variable, minibatch_size * signal_d columns)
        self.state_sig = signals.make_internal(
            "state", (self.state_d, signals.minibatch_size * self.signal_d),
            minibatched=False)

    def build_step(self, signals):
        """Compute the output from the current state, then advance the
        state by one step."""
        input = signals.gather(self.input_data)
        # one row per op
        input = tf.reshape(input, (self.n_ops, -1))

        state = signals.gather(self.state_sig)

        # compute output
        if self.C is None:
            output = tf.zeros_like(input)
        else:
            output = state * self.C

            # sum the C-weighted state rows within each op's block
            output = tf.reshape(
                output,
                (self.n_ops, -1, signals.minibatch_size * self.signal_d))
            output = tf.reduce_sum(output, axis=1)

        if self.D is not None:
            output += self.D * input

        signals.scatter(self.output_data, output)

        # update state
        # TensorFlow versions before 1.7.0 expose this op under a
        # leading-underscore name
        if LooseVersion(tf.__version__) < LooseVersion("1.7.0"):
            mat_mul = gen_sparse_ops._sparse_tensor_dense_mat_mul
        else:
            mat_mul = gen_sparse_ops.sparse_tensor_dense_mat_mul
        # r = (block-diagonal A) @ state
        r = mat_mul(self.A_indices, self.A, self.A_shape, state)

        with tf.control_dependencies([output]):
            # add each op's input row into the first state row of its block
            state = r + tf.scatter_nd(self.offsets, input,
                                      self.state_sig.shape)
            # TODO: tensorflow does not yet support sparse_tensor_dense_add
            # on the GPU
            # state = gen_sparse_ops._sparse_tensor_dense_add(
            #     self.offsets, input, self.state_sig.shape, r)
        state.set_shape(self.state_sig.shape)

        # NOTE(review): presumably informs `signals` that these two signals
        # were read in this step (both were gathered above); confirm the
        # exact semantics of mark_gather
        signals.mark_gather(self.input_data)
        signals.mark_gather(self.state_sig)
        signals.scatter(self.state_sig, state)