"""
Build classes for Nengo neuron operators.
"""
from collections import OrderedDict
import contextlib
import logging
import warnings
from nengo.builder.neurons import SimNeurons
from nengo.neurons import RectifiedLinear, SpikingRectifiedLinear, Sigmoid, LIF, LIFRate
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils
from nengo_dl import compat, utils
from nengo_dl.builder import Builder, OpBuilder
from nengo_dl.neurons import LeakyReLU, SoftLIFRate, SpikingLeakyReLU
logger = logging.getLogger(__name__)
class GenericNeuronBuilder(OpBuilder):
"""
Builds all neuron types for which there is no custom TensorFlow
implementation.
Notes
-----
These will be executed as native Python functions, requiring execution to
move in and out of TensorFlow. This can significantly slow down the
simulation, so any performance-critical neuron models should consider
adding a custom TensorFlow implementation for their neuron type instead.
"""
    def build_pre(self, signals, config):
super().build_pre(signals, config)
self.J_data = signals.combine([op.J for op in self.ops])
self.output_data = signals.combine([op.output for op in self.ops])
state_keys = compat.neuron_state(self.ops[0]).keys()
self.state_data = [
signals.combine([compat.neuron_state(op)[key] for op in self.ops])
for key in state_keys
]
self.prev_result = []
def neuron_step(dt, J, *states): # pragma: no cover (runs in TF)
output = None
J_offset = 0
state_offset = [0 for _ in states]
for op in self.ops:
# slice out the individual state vectors from the overall
# array
op_J = J[:, J_offset : J_offset + op.J.shape[0]]
J_offset += op.J.shape[0]
op_states = []
for j, key in enumerate(state_keys):
s = compat.neuron_state(op)[key]
op_states += [
states[j][:, state_offset[j] : state_offset[j] + s.shape[0]]
]
state_offset[j] += s.shape[0]
# call neuron step function
# note: `op_states` are views into `states`, which will
# be updated in-place
mini_out = []
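                # the underlying Nengo step function operates on one minibatch
                # item at a time, so we loop over the minibatch and stack the
                # results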
for j in range(signals.minibatch_size):
# blank output variable
neuron_output = np.zeros(op.output.shape, self.output_data.dtype)
compat.neuron_step(
op,
dt,
op_J[j],
neuron_output,
dict(zip(state_keys, [s[j] for s in op_states])),
)
mini_out.append(neuron_output)
neuron_output = np.stack(mini_out, axis=0)
# concatenate outputs
if output is None:
output = neuron_output
else:
output = np.concatenate((output, neuron_output), axis=1)
return (output,) + states
self.neuron_step = neuron_step
self.neuron_step.__name__ = utils.sanitize_name(
"_".join([repr(op.neurons) for op in self.ops])
)
    def build_step(self, signals):
J = signals.gather(self.J_data)
states = [signals.gather(x) for x in self.state_data]
states_dtype = [x.dtype for x in self.state_data]
if compat.eager_enabled():
# noop
control_deps = contextlib.suppress()
else:
# we need to make sure that the previous call to this function
# has completed before the next starts, since we don't know that the
# functions are thread safe
control_deps = tf.control_dependencies(self.prev_result)
with control_deps:
ret = tf.numpy_function(
self.neuron_step,
[signals.dt, J] + states,
[self.output_data.dtype] + states_dtype,
name=self.neuron_step.__name__,
)
neuron_out, state_out = ret[0], ret[1:]
self.prev_result = [neuron_out]
neuron_out.set_shape((signals.minibatch_size,) + self.output_data.shape)
signals.scatter(self.output_data, neuron_out)
for i, s in enumerate(self.state_data):
state_out[i].set_shape((signals.minibatch_size,) + s.shape)
signals.scatter(s, state_out[i])
class TFNeuronBuilder(OpBuilder):
"""Base class for `~nengo.neurons.NeuronType` builders with a TF implementation."""
# TODO: this can be delegated to op.neurons.spiking if we increase the minimum
# Nengo version to one where that attribute is guaranteed to exist
spiking = False
    def build_pre(self, signals, config):
super().build_pre(signals, config)
if hasattr(self.ops[0].neurons, "amplitude"):
if all(op.neurons.amplitude == 1 for op in self.ops):
self.amplitude = None
else:
self.amplitude = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"amplitude",
signals.dtype,
)
else:
self.amplitude = None
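        # `alpha` rescales integer spike counts into output values: spikes are
        # multiplied by the neuron amplitude (if any) and divided by dt, so
        # spiking outputs are reported in spikes per second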
self.alpha = 1 if self.amplitude is None else self.amplitude
self.alpha /= signals.dt
self.J_data = signals.combine([op.J for op in self.ops])
self.output_data = signals.combine([op.output for op in self.ops])
self.state_data = OrderedDict(
(
state,
signals.combine([compat.neuron_state(op)[state] for op in self.ops]),
)
for state in compat.neuron_state(self.ops[0])
)
    def step(self, J, dt, **state):
"""Implements the logic for a single inference step."""
raise NotImplementedError("Subclasses must implement")
    def training_step(self, J, dt, **state):
"""
Implements the logic for a single training step.
Note: subclasses only need to implement this if ``spiking=True``. It is used
to specify an alternate (differentiable) implementation of the neuron model
to be used during training.
"""
    def build_step(self, signals, **step_kwargs):
J = signals.gather(self.J_data)
state = OrderedDict((s, signals.gather(d)) for s, d in self.state_data.items())
step_output = tf.nest.flatten(self.step(J, signals.dt, **state))
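        # `step` returns the neuron output, optionally followed by updated
        # state signals; flattening handles both the bare-output and
        # (output, *state) cases uniformly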
if not self.spiking or self.config.inference_only:
out = step_output
else:
out = tf.nest.flatten(
tf_utils.smart_cond(
self.config.training,
true_fn=lambda: (self.training_step(J, signals.dt, **state),)
+ tuple(state.values()),
# we use stop_gradient to avoid propagating any nans (those get
# propagated through the cond even if the spiking version isn't
# being used at all)
false_fn=lambda: tuple(
tf.stop_gradient(x) for x in tf.nest.flatten(step_output)
),
)
)
signals.scatter(self.output_data, out[0])
for state_data, v in zip(self.state_data.values(), out[1:]):
signals.scatter(state_data, v)
class RectifiedLinearBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.RectifiedLinear` neuron operators."""
    def build_pre(self, signals, config):
super().build_pre(signals, config)
if all(getattr(op.neurons, "negative_slope", 0) == 0 for op in self.ops):
self.negative_slope = None
else:
self.negative_slope = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"negative_slope",
signals.dtype,
)
    def step(self, J, dt):
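        # rectify the input current; ops with a nonzero negative_slope
        # (LeakyReLU) also pass negative inputs through, scaled by that slope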
out = tf.nn.relu(J)
if self.negative_slope is not None:
out -= self.negative_slope * tf.nn.relu(-J)
if self.amplitude is not None:
out *= self.amplitude
return out
class SpikingRectifiedLinearBuilder(RectifiedLinearBuilder):
"""Build a group of `~nengo.SpikingRectifiedLinear` neuron operators."""
spiking = True
    def step(self, J, dt, voltage):
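        # integrate the (leaky-)rectified input current into a voltage trace,
        # and emit one spike for each whole unit of voltage accumulated on
        # this timestep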
if self.negative_slope is None:
voltage += tf.nn.relu(J) * dt
n_spikes = tf.floor(voltage)
else:
voltage += (tf.nn.relu(J) - self.negative_slope * tf.nn.relu(-J)) * dt
n_spikes = tf.floor(voltage) + tf.cast(voltage < 0, voltage.dtype)
voltage -= n_spikes
out = n_spikes * self.alpha
return out, voltage
    def training_step(self, J, dt, **state):
return super().step(J, dt)
class SigmoidBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.Sigmoid` neuron operators."""
    def build_pre(self, signals, config):
super().build_pre(signals, config)
self.tau_ref = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"tau_ref",
signals.dtype,
)
    def step(self, J, dt):
return tf.nn.sigmoid(J) / self.tau_ref
class TanhBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.Tanh` neuron operators."""
    def build_pre(self, signals, config):
super().build_pre(signals, config)
self.tau_ref = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"tau_ref",
signals.dtype,
)
    def step(self, J, dt):
return tf.nn.tanh(J) / self.tau_ref
class LIFRateBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.LIFRate` neuron operators."""
    def build_pre(self, signals, config):
super().build_pre(signals, config)
self.tau_ref = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"tau_ref",
signals.dtype,
)
self.tau_rc = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"tau_rc",
signals.dtype,
)
self.zeros = tf.zeros(
(signals.minibatch_size,) + self.J_data.shape, signals.dtype
)
self.epsilon = tf.constant(1e-15, dtype=signals.dtype)
# copy these so that they're easily accessible in the step functions
self.zero = signals.zero
self.one = signals.one
    def step(self, J, dt):
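        # LIF rate curve: for inputs above threshold (J > 1) the steady-state
        # firing rate is amplitude / (tau_ref + tau_rc * log(1 + 1 / (J - 1)));
        # below threshold the rate is zero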
J -= self.one
        # note: we convert all the J values to be positive before this
        # calculation (even though we'll only use the values that are already
        # positive), otherwise we can end up with nans in the gradient
rates = (self.one if self.amplitude is None else self.amplitude) / (
self.tau_ref
+ self.tau_rc
* tf.math.log1p(tf.math.reciprocal(tf.maximum(J, self.epsilon)))
)
return tf.where(J > self.zero, rates, self.zeros)
class SoftLIFRateBuilder(LIFRateBuilder):
"""Build a group of `.SoftLIFRate` neuron operators."""
    def build_pre(self, signals, config):
super().build_pre(signals, config)
self.sigma = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"sigma",
signals.dtype,
)
    def step(self, J, dt):
J -= self.one
js = J / self.sigma
j_valid = js > -20
js_safe = tf.where(j_valid, js, self.zeros)
# softplus(js) = log(1 + e^js)
z = tf.nn.softplus(js_safe) * self.sigma
# as z->0
# z = s*log(1 + e^js) = s*e^js
# log(1 + 1/z) = log(1/z) = -log(s*e^js) = -js - log(s)
q = tf.where(
j_valid, tf.math.log1p(tf.math.reciprocal(z)), -js - tf.math.log(self.sigma)
)
rates = (self.one if self.amplitude is None else self.amplitude) / (
self.tau_ref + self.tau_rc * q
)
return rates
class LIFBuilder(SoftLIFRateBuilder):
"""Build a group of `~nengo.LIF` neuron operators."""
spiking = True
    def build_pre(self, signals, config):
# note: we skip the SoftLIFRateBuilder init
# pylint: disable=bad-super-call
super(SoftLIFRateBuilder, self).build_pre(signals, config)
self.min_voltage = signals.op_constant(
[op.neurons for op in self.ops],
[op.J.shape[0] for op in self.ops],
"min_voltage",
signals.dtype,
)
if self.config.lif_smoothing:
self.sigma = tf.constant(self.config.lif_smoothing, dtype=signals.dtype)
    def step(self, J, dt, voltage, refractory_time):
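        # advance the membrane voltage over the non-refractory portion of the
        # timestep, using the exact solution
        # v(t + delta_t) = J + (v - J) * exp(-delta_t / tau_rc)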
delta_t = tf.clip_by_value(dt - refractory_time, self.zero, dt)
dV = (voltage - J) * tf.math.expm1(-delta_t / self.tau_rc)
voltage += dV
spiked = voltage > self.one
spikes = tf.cast(spiked, J.dtype) * self.alpha
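        # for neurons that spiked, solve for how far into the timestep the
        # threshold was crossed, so the refractory period starts from the
        # actual spike time rather than the end of the step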
partial_ref = -self.tau_rc * tf.math.log1p(
(self.one - voltage) / (J - self.one)
)
# FastLIF version (linearly approximate spike time when calculating
# remaining refractory period)
# partial_ref = signals.dt * (voltage - self.one) / dV
refractory_time = tf.where(
spiked, self.tau_ref - partial_ref, refractory_time - dt
)
voltage = tf.where(spiked, self.zeros, tf.maximum(voltage, self.min_voltage))
return spikes, voltage, refractory_time
    def training_step(self, J, dt, **state):
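        # during training the discontinuous spiking step is replaced by the
        # differentiable LIF rate approximation (or its softened version if
        # lif_smoothing is set)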
return (
LIFRateBuilder.step(self, J, dt)
if self.config.lif_smoothing is None
else SoftLIFRateBuilder.step(self, J, dt)
)
class RegularSpikingBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.RegularSpiking` neuron operators."""
spiking = True
    def step(self, J, dt, voltage):
voltage += J * dt
n_spikes = tf.floor(voltage)
voltage -= n_spikes
out = n_spikes * self.alpha
return out, voltage
    def training_step(self, J, dt, **state):
return J if self.amplitude is None else J * self.amplitude
class StochasticSpikingBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.StochasticSpiking` neuron operators."""
spiking = True
    def step(self, J, dt):
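        # the expected spike count this step is |J| * dt; the integer part is
        # emitted deterministically, plus one extra spike with probability
        # equal to the remaining fraction, then the count is given the sign of
        # J and rescaled by alpha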
x = dt * tf.math.abs(J)
n_spikes = tf.floor(x)
frac = x - n_spikes
n_spikes += tf.cast(tf.random.uniform(frac.shape) < frac, n_spikes.dtype)
n_spikes *= self.alpha * tf.math.sign(J)
return n_spikes
    def training_step(self, J, dt):
return J if self.amplitude is None else J * self.amplitude
class PoissonSpikingBuilder(TFNeuronBuilder):
"""Build a group of `~nengo.PoissonSpiking` neuron operators."""
spiking = True
    def step(self, J, dt):
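        # the spike count this step is drawn from a Poisson distribution with
        # mean |J| * dt, then rescaled by alpha and given the sign of J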
n_spikes = (
self.alpha
* tf.random.poisson((), tf.math.abs(J) * dt, dtype=J.dtype)
* tf.math.sign(J)
)
n_spikes.set_shape(J.shape)
return n_spikes
    def training_step(self, J, dt):
return J if self.amplitude is None else J * self.amplitude
@Builder.register(SimNeurons)
class SimNeuronsBuilder(OpBuilder):
"""
Builds a group of `~nengo.builder.neurons.SimNeurons` operators.
Calls the appropriate sub-build class for the different neuron types.
Attributes
----------
TF_NEURON_IMPL : dict of {`~nengo.neurons.NeuronType`, \
`.builder.OpBuilder`}
Mapping from neuron types to custom build classes (neurons without
a custom builder will use the generic builder).
"""
TF_NEURON_IMPL = {
RectifiedLinear: RectifiedLinearBuilder,
SpikingRectifiedLinear: SpikingRectifiedLinearBuilder,
LeakyReLU: RectifiedLinearBuilder,
SpikingLeakyReLU: SpikingRectifiedLinearBuilder,
Sigmoid: SigmoidBuilder,
compat.Tanh: TanhBuilder,
LIF: LIFBuilder,
LIFRate: LIFRateBuilder,
SoftLIFRate: SoftLIFRateBuilder,
compat.RegularSpiking: RegularSpikingBuilder,
compat.StochasticSpiking: StochasticSpikingBuilder,
compat.PoissonSpiking: PoissonSpikingBuilder,
}
def __init__(self, ops):
super().__init__(ops)
neuron_type = type(ops[0].neurons)
        # if we have a custom TensorFlow implementation for this neuron type,
        # then we build that. otherwise we'll just execute the neuron step
        # function externally (using `tf.numpy_function`).
if neuron_type in self.TF_NEURON_IMPL:
self.built_neurons = self.TF_NEURON_IMPL[neuron_type](ops)
else:
warnings.warn(
"%s does not have a native TensorFlow implementation; "
"falling back to Python implementation" % neuron_type
)
self.built_neurons = GenericNeuronBuilder(ops)
    def build_pre(self, signals, config):
self.built_neurons.build_pre(signals, config)
    def build_step(self, signals):
self.built_neurons.build_step(signals)
    @staticmethod
def mergeable(x, y):
# neuron ops must all have the same type
return type(x.neurons) == type(y.neurons) # noqa: E721