"""
Build classes for Nengo neuron operators.
"""
import logging
from nengo.builder.neurons import SimNeurons
from nengo.neurons import RectifiedLinear, SpikingRectifiedLinear, Sigmoid, LIF, LIFRate
import numpy as np
import tensorflow as tf
from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder
from nengo_dl.compat import tf_compat, tf_math
from nengo_dl.neurons import SoftLIFRate
logger = logging.getLogger(__name__)
[docs]class GenericNeuronBuilder(OpBuilder):
"""
Builds all neuron types for which there is no custom Tensorflow
implementation.
Notes
-----
These will be executed as native Python functions, requiring execution to
move in and out of TensorFlow. This can significantly slow down the
simulation, so any performance-critical neuron models should consider
adding a custom TensorFlow implementation for their neuron type instead.
"""
def __init__(self, ops, signals, config):
super(GenericNeuronBuilder, self).__init__(ops, signals, config)
self.J_data = signals.combine([op.J for op in ops])
self.output_data = signals.combine([op.output for op in ops])
self.state_data = [
signals.combine([op.states[i] for op in ops])
for i in range(len(ops[0].states))
]
self.prev_result = []
def neuron_step_math(dt, J, *states): # pragma: no cover (runs in TF)
output = None
J_offset = 0
state_offset = [0 for _ in states]
for op in ops:
# slice out the individual state vectors from the overall
# array
op_J = J[J_offset : J_offset + op.J.shape[0]]
J_offset += op.J.shape[0]
op_states = []
for j, s in enumerate(op.states):
op_states += [
states[j][state_offset[j] : state_offset[j] + s.shape[0]]
]
state_offset[j] += s.shape[0]
# call step_math function
# note: `op_states` are views into `states`, which will
# be updated in-place
mini_out = []
for j in range(signals.minibatch_size):
# blank output variable
neuron_output = np.zeros(op.output.shape, self.output_data.dtype)
op.neurons.step_math(
dt, op_J[..., j], neuron_output, *[s[..., j] for s in op_states]
)
mini_out += [neuron_output]
neuron_output = np.stack(mini_out, axis=-1)
# concatenate outputs
if output is None:
output = neuron_output
else:
output = np.concatenate((output, neuron_output), axis=0)
return (output,) + states
self.neuron_step_math = neuron_step_math
self.neuron_step_math.__name__ = utils.sanitize_name(
"_".join([repr(op.neurons) for op in ops])
)
[docs] def build_step(self, signals):
J = signals.gather(self.J_data)
states = [signals.gather(x) for x in self.state_data]
states_dtype = [x.dtype for x in self.state_data]
# note: we need to make sure that the previous call to this function
# has completed before the next starts, since we don't know that the
# functions are thread safe
with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
ret = tf_compat.py_func(
self.neuron_step_math,
[signals.dt, J] + states,
[self.output_data.dtype] + states_dtype,
name=self.neuron_step_math.__name__,
)
neuron_out, state_out = ret[0], ret[1:]
self.prev_result = [neuron_out]
neuron_out.set_shape(self.output_data.shape + (signals.minibatch_size,))
signals.scatter(self.output_data, neuron_out)
for i, s in enumerate(self.state_data):
state_out[i].set_shape(s.shape + (signals.minibatch_size,))
signals.scatter(s, state_out[i])
[docs]class RectifiedLinearBuilder(OpBuilder):
"""Build a group of `~nengo.RectifiedLinear`
neuron operators."""
def __init__(self, ops, signals, config):
super(RectifiedLinearBuilder, self).__init__(ops, signals, config)
self.J_data = signals.combine([op.J for op in ops])
self.output_data = signals.combine([op.output for op in ops])
if all(op.neurons.amplitude == 1 for op in ops):
self.amplitude = None
else:
self.amplitude = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"amplitude",
signals.dtype,
)
def _step(self, J):
out = tf.nn.relu(J)
if self.amplitude is not None:
out *= self.amplitude
return out
[docs] def build_step(self, signals):
J = signals.gather(self.J_data)
out = self._step(J)
signals.scatter(self.output_data, out)
[docs]class SpikingRectifiedLinearBuilder(RectifiedLinearBuilder):
"""Build a group of `~nengo.SpikingRectifiedLinear` neuron
operators."""
def __init__(self, ops, signals, config):
super(SpikingRectifiedLinearBuilder, self).__init__(ops, signals, config)
self.voltage_data = signals.combine([op.states[0] for op in ops])
self.alpha = 1 if self.amplitude is None else self.amplitude
self.alpha /= signals.dt
def _step(self, J, voltage, dt):
voltage += tf.nn.relu(J) * dt
n_spikes = tf.floor(voltage)
voltage -= n_spikes
out = n_spikes * self.alpha
# we use stop_gradient to avoid propagating any nans (those get
# propagated through the cond even if the spiking version isn't
# being used at all)
return tf.stop_gradient(out), tf.stop_gradient(voltage)
[docs] def build_step(self, signals):
J = signals.gather(self.J_data)
voltage = signals.gather(self.voltage_data)
spike_out, spike_voltage = self._step(J, voltage, signals.dt)
if self.config.inference_only:
out, voltage = spike_out, spike_voltage
else:
rate_out = super(SpikingRectifiedLinearBuilder, self)._step(J)
out, voltage = tf.cond(
pred=signals.training,
true_fn=lambda: (rate_out, voltage),
false_fn=lambda: (spike_out, spike_voltage),
)
signals.scatter(self.output_data, out)
signals.scatter(self.voltage_data, voltage)
[docs]class SigmoidBuilder(OpBuilder):
"""Build a group of `~nengo.Sigmoid` neuron operators."""
def __init__(self, ops, signals, config):
super(SigmoidBuilder, self).__init__(ops, signals, config)
self.J_data = signals.combine([op.J for op in ops])
self.output_data = signals.combine([op.output for op in ops])
self.tau_ref = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"tau_ref",
signals.dtype,
)
[docs] def build_step(self, signals):
J = signals.gather(self.J_data)
signals.scatter(self.output_data, tf.nn.sigmoid(J) / self.tau_ref)
[docs]class LIFRateBuilder(OpBuilder):
"""Build a group of `~nengo.LIFRate` neuron operators."""
def __init__(self, ops, signals, config):
super(LIFRateBuilder, self).__init__(ops, signals, config)
self.tau_ref = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"tau_ref",
signals.dtype,
)
self.tau_rc = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"tau_rc",
signals.dtype,
)
self.amplitude = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"amplitude",
signals.dtype,
)
self.J_data = signals.combine([op.J for op in ops])
self.output_data = signals.combine([op.output for op in ops])
self.zeros = tf.zeros(
self.J_data.shape + (signals.minibatch_size,), signals.dtype
)
self.epsilon = tf.constant(1e-15, dtype=signals.dtype)
# copy these so that they're easily accessible in the _step functions
self.zero = signals.zero
self.one = signals.one
def _step(self, j):
j -= self.one
# note: we convert all the j to be positive before this calculation
# (even though we'll only use the values that are already positive),
# otherwise we can end up with nans in the gradient
rates = self.amplitude / (
self.tau_ref
+ self.tau_rc
* tf_math.log1p(tf_math.reciprocal(tf.maximum(j, self.epsilon)))
)
return tf.where(j > self.zero, rates, self.zeros)
[docs] def build_step(self, signals):
j = signals.gather(self.J_data)
rates = self._step(j)
signals.scatter(self.output_data, rates)
[docs]class SoftLIFRateBuilder(LIFRateBuilder):
"""Build a group of `.SoftLIFRate` neuron operators."""
def __init__(self, ops, signals, config):
super(SoftLIFRateBuilder, self).__init__(ops, signals, config)
self.sigma = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"sigma",
signals.dtype,
)
def _step(self, J):
J -= self.one
js = J / self.sigma
j_valid = js > -20
js_safe = tf.where(j_valid, js, self.zeros)
# softplus(js) = log(1 + e^js)
z = tf.nn.softplus(js_safe) * self.sigma
# as z->0
# z = s*log(1 + e^js) = s*e^js
# log(1 + 1/z) = log(1/z) = -log(s*e^js) = -js - log(s)
q = tf.where(
j_valid, tf_math.log1p(tf_math.reciprocal(z)), -js - tf_math.log(self.sigma)
)
rates = self.amplitude / (self.tau_ref + self.tau_rc * q)
return rates
[docs] def build_step(self, signals):
j = signals.gather(self.J_data)
rates = self._step(j)
signals.scatter(self.output_data, rates)
[docs]class LIFBuilder(SoftLIFRateBuilder):
"""Build a group of `~nengo.LIF` neuron operators."""
def __init__(self, ops, signals, config):
# note: we skip the SoftLIFRateBuilder init
# pylint: disable=bad-super-call
super(SoftLIFRateBuilder, self).__init__(ops, signals, config)
self.min_voltage = signals.op_constant(
[op.neurons for op in ops],
[op.J.shape[0] for op in ops],
"min_voltage",
signals.dtype,
)
self.alpha = self.amplitude / signals.dt
self.voltage_data = signals.combine([op.states[0] for op in ops])
self.refractory_data = signals.combine([op.states[1] for op in ops])
if self.config.lif_smoothing:
self.sigma = tf.constant(self.config.lif_smoothing, dtype=signals.dtype)
def _step(self, J, voltage, refractory, dt):
delta_t = tf.clip_by_value(dt - refractory, self.zero, dt)
dV = (voltage - J) * tf_math.expm1(-delta_t / self.tau_rc)
voltage += dV
spiked = voltage > self.one
spikes = tf.cast(spiked, J.dtype) * self.alpha
partial_ref = -self.tau_rc * tf_math.log1p(
(self.one - voltage) / (J - self.one)
)
# FastLIF version (linearly approximate spike time when calculating
# remaining refractory period)
# partial_ref = signals.dt * (voltage - self.one) / dV
refractory = tf.where(spiked, self.tau_ref - partial_ref, refractory - dt)
voltage = tf.where(spiked, self.zeros, tf.maximum(voltage, self.min_voltage))
# we use stop_gradient to avoid propagating any nans (those get
# propagated through the cond even if the spiking version isn't
# being used at all)
return (
tf.stop_gradient(spikes),
tf.stop_gradient(voltage),
tf.stop_gradient(refractory),
)
[docs] def build_step(self, signals):
J = signals.gather(self.J_data)
voltage = signals.gather(self.voltage_data)
refractory = signals.gather(self.refractory_data)
spike_out, spike_voltage, spike_ref = self._step(
J, voltage, refractory, signals.dt
)
if self.config.inference_only:
spikes, voltage, refractory = spike_out, spike_voltage, spike_ref
else:
rate_out = (
LIFRateBuilder._step(self, J)
if self.config.lif_smoothing is None
else SoftLIFRateBuilder._step(self, J)
)
spikes, voltage, refractory = tf.cond(
pred=signals.training,
true_fn=lambda: (rate_out, voltage, refractory),
false_fn=lambda: (spike_out, spike_voltage, spike_ref),
)
signals.scatter(self.output_data, spikes)
signals.mark_gather(self.J_data)
signals.scatter(self.refractory_data, refractory)
signals.scatter(self.voltage_data, voltage)
[docs]@Builder.register(SimNeurons)
class SimNeuronsBuilder(OpBuilder):
"""
Builds a group of `~nengo.builder.neurons.SimNeurons` operators.
Calls the appropriate sub-build class for the different neuron types.
Attributes
----------
TF_NEURON_IMPL : dict of {`~nengo.neurons.NeuronType`, \
`.builder.OpBuilder`}
Mapping from neuron types to custom build classes (neurons without
a custom builder will use the generic builder).
"""
TF_NEURON_IMPL = {
RectifiedLinear: RectifiedLinearBuilder,
SpikingRectifiedLinear: SpikingRectifiedLinearBuilder,
Sigmoid: SigmoidBuilder,
LIF: LIFBuilder,
LIFRate: LIFRateBuilder,
SoftLIFRate: SoftLIFRateBuilder,
}
def __init__(self, ops, signals, config):
super(SimNeuronsBuilder, self).__init__(ops, signals, config)
logger.debug("J %s", [op.J for op in ops])
neuron_type = type(ops[0].neurons)
# if we have a custom tensorflow implementation for this neuron type,
# then we build that. otherwise we'll just execute the neuron step
# function externally (using `tf.py_func`).
if neuron_type in self.TF_NEURON_IMPL:
self.built_neurons = self.TF_NEURON_IMPL[neuron_type](ops, signals, config)
else:
self.built_neurons = GenericNeuronBuilder(ops, signals, config)
[docs] def build_step(self, signals):
self.built_neurons.build_step(signals)
[docs] @staticmethod
def mergeable(x, y):
# neuron ops must all have the same type
return type(x.neurons) == type(y.neurons) # noqa: E721