import logging
from nengo.builder.neurons import SimNeurons
from nengo.neurons import (RectifiedLinear, SpikingRectifiedLinear, Sigmoid,
LIF, LIFRate)
import numpy as np
import tensorflow as tf
from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder
from nengo_dl.neurons import SoftLIFRate
logger = logging.getLogger(__name__)
class GenericNeuronBuilder(OpBuilder):
    """
    Fallback builder for neuron types that have no custom TensorFlow
    implementation.

    Notes
    -----
    The neuron ``step_math`` functions are run as native Python code (via
    ``tf.py_func``), so execution has to move in and out of TensorFlow on
    every step. This can significantly slow down the simulation; any
    performance-critical neuron model should add a custom TensorFlow
    implementation for its neuron type instead.
    """

    def __init__(self, ops, signals, config):
        super(GenericNeuronBuilder, self).__init__(ops, signals, config)

        # merge the per-op signals into single blocks so that each can be
        # read/written with one gather/scatter
        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        self.state_data = [
            signals.combine([op.states[idx] for op in ops])
            for idx, _ in enumerate(ops[0].states)]

        # outputs of the previous `build_step` call, used to serialize
        # successive executions of the Python function
        self.prev_result = []

        def neuron_step_math(dt, J, *states):  # pragma: no cover
            # running offsets into the combined input/state arrays
            J_start = 0
            state_starts = [0] * len(states)
            op_outputs = []

            for op in ops:
                # slice this op's rows out of the combined arrays
                n_inputs = op.J.shape[0]
                op_J = J[J_start:J_start + n_inputs]
                J_start += n_inputs

                op_states = []
                for k, sig in enumerate(op.states):
                    begin = state_starts[k]
                    end = begin + sig.shape[0]
                    op_states.append(states[k][begin:end])
                    state_starts[k] = end

                # run step_math once per minibatch item
                # note: the entries of `op_states` are views into `states`,
                # so the state updates are written in-place
                per_item = []
                for b in range(signals.minibatch_size):
                    # blank output buffer for step_math to fill in
                    item_out = np.zeros(
                        op.output.shape, self.output_data.dtype)
                    item_states = [s[..., b] for s in op_states]
                    op.neurons.step_math(
                        dt, op_J[..., b], item_out, *item_states)
                    per_item.append(item_out)

                # minibatch is the trailing axis
                op_outputs.append(np.stack(per_item, axis=-1))

            # join the per-op outputs along the neuron axis
            if op_outputs:
                output = np.concatenate(op_outputs, axis=0)
            else:
                output = None

            return (output,) + states

        # give the function a name that reflects the neurons it simulates
        # (reused as the name of the py_func op)
        neuron_step_math.__name__ = utils.sanitize_name(
            "_".join([repr(op.neurons) for op in ops]))
        self.neuron_step_math = neuron_step_math

    def build_step(self, signals):
        J = signals.gather(self.J_data)
        states = [signals.gather(sig) for sig in self.state_data]
        states_dtype = [sig.dtype for sig in self.state_data]

        # note: the previous call to this function must have completed
        # before the next one starts, since we don't know that the step
        # functions are thread safe
        with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
            ret = tf.py_func(
                self.neuron_step_math, [signals.dt, J] + states,
                [self.output_data.dtype] + states_dtype,
                name=self.neuron_step_math.__name__)
        neuron_out = ret[0]
        state_out = ret[1:]
        self.prev_result = [neuron_out]

        # py_func outputs carry no static shape, so set it explicitly
        batch_dim = (signals.minibatch_size,)
        neuron_out.set_shape(self.output_data.shape + batch_dim)
        signals.scatter(self.output_data, neuron_out)

        for sig, new_state in zip(self.state_data, state_out):
            new_state.set_shape(sig.shape + (signals.minibatch_size,))
            signals.scatter(sig, new_state)
class RectifiedLinearBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.RectifiedLinear`
    neuron operators."""

    def __init__(self, ops, signals, config):
        super(RectifiedLinearBuilder, self).__init__(ops, signals, config)

        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        # when every op has unit amplitude the scaling step is skipped
        # entirely (signalled by `amplitude = None`)
        unit_amplitude = all(op.neurons.amplitude == 1 for op in ops)
        self.amplitude = (
            None if unit_amplitude else
            signals.op_constant(
                [op.neurons for op in ops], [op.J.shape[0] for op in ops],
                "amplitude", signals.dtype))

    def _step(self, J):
        """Return ``relu(J)``, scaled by the amplitude if one is set."""
        rates = tf.nn.relu(J)
        if self.amplitude is None:
            return rates
        return rates * self.amplitude

    def build_step(self, signals):
        rates = self._step(signals.gather(self.J_data))
        signals.scatter(self.output_data, rates)
class SpikingRectifiedLinearBuilder(RectifiedLinearBuilder):
    """Build a group of :class:`~nengo:nengo.SpikingRectifiedLinear` neuron
    operators."""

    def __init__(self, ops, signals, config):
        super(SpikingRectifiedLinearBuilder, self).__init__(
            ops, signals, config)

        self.voltage_data = signals.combine([op.states[0] for op in ops])

        # output value per spike: amplitude (1 if unset) divided by dt
        spike_scale = 1 if self.amplitude is None else self.amplitude
        self.alpha = spike_scale / signals.dt

    def _step(self, J, voltage, dt):
        """Integrate the rectified input and emit the whole-number part of
        the voltage as spikes.

        Returns the (scaled) spike output and the remaining voltage, both
        with gradients blocked.
        """
        charged = voltage + tf.nn.relu(J) * dt
        n_spikes = tf.floor(charged)
        leftover = charged - n_spikes
        spikes = n_spikes * self.alpha

        # we use stop_gradient to avoid propagating any nans (those get
        # propagated through the cond even if the spiking version isn't
        # being used at all)
        return tf.stop_gradient(spikes), tf.stop_gradient(leftover)

    def build_step(self, signals):
        J = signals.gather(self.J_data)
        prev_voltage = signals.gather(self.voltage_data)

        spike_out, spike_voltage = self._step(J, prev_voltage, signals.dt)

        if self.config.inference_only:
            # only the spiking behaviour is built
            new_out, new_voltage = spike_out, spike_voltage
        else:
            # rate output (voltage left untouched) while training,
            # spiking output otherwise
            rate_out = super(SpikingRectifiedLinearBuilder, self)._step(J)

            def training_branch():
                return rate_out, prev_voltage

            def spiking_branch():
                return spike_out, spike_voltage

            new_out, new_voltage = tf.cond(
                signals.training, training_branch, spiking_branch)

        signals.scatter(self.output_data, new_out)
        signals.scatter(self.voltage_data, new_voltage)
class SigmoidBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.Sigmoid` neuron operators."""

    def __init__(self, ops, signals, config):
        super(SigmoidBuilder, self).__init__(ops, signals, config)

        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        # per-neuron refractory time constant, used to scale the output
        neuron_types = [op.neurons for op in ops]
        n_neurons = [op.J.shape[0] for op in ops]
        self.tau_ref = signals.op_constant(
            neuron_types, n_neurons, "tau_ref", signals.dtype)

    def build_step(self, signals):
        # output is sigmoid(J) divided by tau_ref
        J = signals.gather(self.J_data)
        rates = tf.nn.sigmoid(J) / self.tau_ref
        signals.scatter(self.output_data, rates)
class LIFRateBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.LIFRate` neuron operators."""

    def __init__(self, ops, signals, config):
        super(LIFRateBuilder, self).__init__(ops, signals, config)

        def neuron_constant(attr):
            # per-neuron constant built from the named neuron attribute
            return signals.op_constant(
                [op.neurons for op in ops], [op.J.shape[0] for op in ops],
                attr, signals.dtype)

        self.tau_ref = neuron_constant("tau_ref")
        self.tau_rc = neuron_constant("tau_rc")
        self.amplitude = neuron_constant("amplitude")

        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        # zero-filled tensor with the same shape as the gathered input
        self.zeros = tf.zeros(
            self.J_data.shape + (signals.minibatch_size,), signals.dtype)

        # lower bound applied to the input before taking its reciprocal
        self.epsilon = tf.constant(1e-15, dtype=signals.dtype)

        # copy these so that they're easily accessible in the _step functions
        self.zero = signals.zero
        self.one = signals.one

    def _step(self, j):
        """Compute ``amplitude / (tau_ref + tau_rc * log1p(1 / (j - 1)))``
        where ``j > 1``, and zero elsewhere."""
        shifted = j - self.one

        # note: we convert all the j to be positive before this calculation
        # (even though we'll only use the values that are already positive),
        # otherwise we can end up with nans in the gradient
        positive = tf.maximum(shifted, self.epsilon)
        period = self.tau_ref + self.tau_rc * tf.log1p(
            tf.reciprocal(positive))
        rates = self.amplitude / period

        return tf.where(shifted > self.zero, rates, self.zeros)

    def build_step(self, signals):
        rates = self._step(signals.gather(self.J_data))
        signals.scatter(self.output_data, rates)
class SoftLIFRateBuilder(LIFRateBuilder):
    """Build a group of :class:`.SoftLIFRate` neuron operators."""

    def __init__(self, ops, signals, config):
        super(SoftLIFRateBuilder, self).__init__(ops, signals, config)

        # per-neuron smoothing parameter
        neuron_types = [op.neurons for op in ops]
        n_neurons = [op.J.shape[0] for op in ops]
        self.sigma = signals.op_constant(
            neuron_types, n_neurons, "sigma", signals.dtype)

    def _step(self, J):
        """Compute the smoothed LIF rate, in which the hard threshold on
        ``J - 1`` is replaced by a softplus scaled by ``sigma``."""
        scaled = (J - self.one) / self.sigma

        # below this cutoff the softplus is replaced by its asymptotic form;
        # the masked-out entries are set to zero before the softplus
        in_range = scaled > -20
        scaled_safe = tf.where(in_range, scaled, self.zeros)

        # softplus(js) = log(1 + e^js)
        z = tf.nn.softplus(scaled_safe) * self.sigma

        # as z->0
        #   z = s*log(1 + e^js) = s*e^js
        #   log(1 + 1/z) = log(1/z) = -log(s*e^js) = -js - log(s)
        exact = tf.log1p(tf.reciprocal(z))
        asymptotic = -scaled - tf.log(self.sigma)
        q = tf.where(in_range, exact, asymptotic)

        return self.amplitude / (self.tau_ref + self.tau_rc * q)

    def build_step(self, signals):
        rates = self._step(signals.gather(self.J_data))
        signals.scatter(self.output_data, rates)
class LIFBuilder(SoftLIFRateBuilder):
    """Build a group of :class:`~nengo:nengo.LIF` neuron operators.

    Unless ``config.inference_only`` is set, the rate-based output (LIFRate,
    or SoftLIFRate when ``config.lif_smoothing`` is enabled) is used while
    ``signals.training`` is true and the spiking output is used otherwise.
    """

    def __init__(self, ops, signals, config):
        # note: we skip the SoftLIFRateBuilder init
        # pylint: disable=bad-super-call
        super(SoftLIFRateBuilder, self).__init__(ops, signals, config)

        self.min_voltage = signals.op_constant(
            [op.neurons for op in ops], [op.J.shape[0] for op in ops],
            "min_voltage", signals.dtype)

        # output value emitted for a spike
        self.alpha = self.amplitude / signals.dt

        self.voltage_data = signals.combine([op.states[0] for op in ops])
        self.refractory_data = signals.combine([op.states[1] for op in ops])

        # `sigma` (read by SoftLIFRateBuilder._step) only exists when
        # smoothing is enabled; `build_step` must use the same test
        if self.config.lif_smoothing:
            self.sigma = tf.constant(self.config.lif_smoothing,
                                     dtype=signals.dtype)

    def _step(self, J, voltage, refractory, dt):
        """Advance the spiking LIF state by one timestep.

        Returns the spike output, new voltage and new refractory time, all
        with gradients blocked.
        """
        # portion of this timestep not spent in the refractory period
        delta_t = tf.clip_by_value(dt - refractory, self.zero, dt)

        # update the voltage over the non-refractory part of the step
        dV = (voltage - J) * tf.expm1(-delta_t / self.tau_rc)
        voltage += dV

        spiked = voltage > self.one
        spikes = tf.cast(spiked, J.dtype) * self.alpha

        # amount subtracted from tau_ref when setting the refractory time
        # of a neuron that spiked
        partial_ref = -self.tau_rc * tf.log1p((self.one - voltage) /
                                              (J - self.one))
        # FastLIF version (linearly approximate spike time when calculating
        # remaining refractory period)
        # partial_ref = signals.dt * (voltage - self.one) / dV

        # spiking neurons start a new refractory period and reset their
        # voltage to zero; the others count down and are clipped from below
        refractory = tf.where(spiked, self.tau_ref - partial_ref,
                              refractory - dt)
        voltage = tf.where(spiked, self.zeros,
                           tf.maximum(voltage, self.min_voltage))

        # we use stop_gradient to avoid propagating any nans (those get
        # propagated through the cond even if the spiking version isn't
        # being used at all)
        return (tf.stop_gradient(spikes), tf.stop_gradient(voltage),
                tf.stop_gradient(refractory))

    def build_step(self, signals):
        J = signals.gather(self.J_data)
        prev_voltage = signals.gather(self.voltage_data)
        prev_refractory = signals.gather(self.refractory_data)

        spike_out, spike_voltage, spike_ref = self._step(
            J, prev_voltage, prev_refractory, signals.dt)

        if self.config.inference_only:
            spikes, voltage, refractory = spike_out, spike_voltage, spike_ref
        else:
            # use the same truthiness test as `__init__`, so that a falsy
            # `lif_smoothing` (None or 0) consistently selects the unsmoothed
            # rate function; testing `is None` here would send a value of 0
            # down the SoftLIF path, where `self.sigma` was never created
            if self.config.lif_smoothing:
                rate_out = SoftLIFRateBuilder._step(self, J)
            else:
                rate_out = LIFRateBuilder._step(self, J)

            # rate output (state left untouched) while training, spiking
            # output otherwise
            spikes, voltage, refractory = tf.cond(
                signals.training,
                lambda: (rate_out, prev_voltage, prev_refractory),
                lambda: (spike_out, spike_voltage, spike_ref)
            )

        signals.scatter(self.output_data, spikes)

        # NOTE(review): presumably records the read of J for the signal
        # read/write ordering before the state scatters -- confirm against
        # the `mark_gather` implementation
        signals.mark_gather(self.J_data)
        signals.scatter(self.refractory_data, refractory)
        signals.scatter(self.voltage_data, voltage)
@Builder.register(SimNeurons)
class SimNeuronsBuilder(OpBuilder):
    """
    Builds a group of :class:`~nengo:nengo.builder.neurons.SimNeurons`
    operators.

    Delegates to the sub-build class that matches the neuron type of the
    operators.

    Attributes
    ----------
    TF_NEURON_IMPL : dict of {:class:`~nengo:nengo.neurons.NeuronType`, \
                              :class:`.builder.OpBuilder`}
        Mapping from neuron types to custom build classes (neurons without
        a custom builder will use the generic builder).
    """

    TF_NEURON_IMPL = {
        RectifiedLinear: RectifiedLinearBuilder,
        SpikingRectifiedLinear: SpikingRectifiedLinearBuilder,
        Sigmoid: SigmoidBuilder,
        LIF: LIFBuilder,
        LIFRate: LIFRateBuilder,
        SoftLIFRate: SoftLIFRateBuilder,
    }

    def __init__(self, ops, signals, config):
        super(SimNeuronsBuilder, self).__init__(ops, signals, config)

        logger.debug("J %s", [op.J for op in ops])

        # use the custom tensorflow implementation for this neuron type if
        # one is registered; otherwise fall back to executing the neuron
        # step function externally (using `tf.py_func`).
        # note: the lookup uses the exact type of the first op's neurons
        builder_cls = self.TF_NEURON_IMPL.get(
            type(ops[0].neurons), GenericNeuronBuilder)
        self.built_neurons = builder_cls(ops, signals, config)

    def build_step(self, signals):
        # all the work is done by the neuron-specific builder
        self.built_neurons.build_step(signals)