Source code for nengo_dl.neurons

import logging

from nengo.neurons import RectifiedLinear, Sigmoid, LIF, LIFRate
from nengo.builder.neurons import SimNeurons
from nengo.params import NumberParam
import numpy as np
import tensorflow as tf

from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder

# module-level logger; SimNeuronsBuilder sends its debug output here
logger = logging.getLogger(__name__)


class SoftLIFRate(LIFRate):
    """Rate-based LIF neuron whose tuning curve is smoothed near threshold.

    The hard threshold of the LIF rate equation is replaced by a softplus,
    ``sigma * log(1 + exp((J - 1) / sigma))``, so the tuning curve has a
    continuous first derivative. This makes the neuron usable as a stand-in
    for LIF neurons while training deep networks; the trained network can then
    be run with ordinary LIF neurons [1]_.

    Parameters
    ----------
    sigma : float
        Amount of smoothing around the firing threshold; larger values give
        more smoothing. Must be strictly positive.
    tau_rc : float
        Membrane RC time constant, in seconds. Affects how quickly the
        membrane voltage decays to zero in the absence of input (larger =
        slower decay).
    tau_ref : float
        Absolute refractory period, in seconds. This is how long the membrane
        voltage is held at zero after a spike.

    References
    ----------
    .. [1] E. Hunsberger & C. Eliasmith (2015). Spiking Deep Networks with
       LIF Neurons. arXiv Preprint, 1510. http://arxiv.org/abs/1510.08829

    Notes
    -----
    Adapted from
    https://github.com/nengo/nengo_extras/blob/master/nengo_extras/neurons.py
    """

    sigma = NumberParam('sigma', low=0, low_open=True)

    def __init__(self, sigma=1., **lif_args):
        # remaining keyword arguments (tau_rc, tau_ref, ...) go to LIFRate
        super(SoftLIFRate, self).__init__(**lif_args)
        self.sigma = sigma

    @property
    def _argreprs(self):
        # only mention sigma in the repr when it differs from the default
        reprs = super(SoftLIFRate, self)._argreprs
        if self.sigma == 1.:
            return reprs
        reprs.append("sigma=%s" % self.sigma)
        return reprs

    def rates(self, x, gain, bias):
        """Return the firing rates for inputs ``x`` given ``gain`` and
        ``bias`` (current is ``gain * x + bias``)."""
        current = gain * x
        current += bias
        result = np.zeros_like(current)
        self.step_math(dt=1, J=current, output=result)
        return result

    def step_math(self, dt, J, output):
        """Write firing rates (Hz) for input current ``J`` (bias included)
        into ``output``. ``dt`` is not used by this rate model."""
        x = J - 1
        y = x / self.sigma

        # softplus smoothing; for y >= 34, log1p(exp(y)) equals y to working
        # precision, so those entries keep the unsmoothed value x
        smoothable = y < 34
        soft = y[smoothable]
        np.exp(soft, out=soft)
        np.log1p(soft, out=soft)
        soft *= self.sigma
        x[smoothable] = soft

        # standard LIF rate equation wherever the (smoothed) current exceeds
        # threshold; zero elsewhere
        firing = x > 0
        output[:] = 0
        output[firing] = 1. / (
            self.tau_ref + self.tau_rc * np.log1p(1. / x[firing]))
@Builder.register(SimNeurons)
class SimNeuronsBuilder(OpBuilder):
    """Builds a group of :class:`~nengo:nengo.builder.neurons.SimNeurons`
    operators.

    Calls the appropriate sub-build class for the different neuron types.

    Attributes
    ----------
    TF_NEURON_IMPL : tuple of :class:`~nengo:nengo.neurons.NeuronType`
        the neuron types that have a custom implementation
    """

    TF_NEURON_IMPL = (RectifiedLinear, Sigmoid, LIF, LIFRate, SoftLIFRate)

    def __init__(self, ops, signals):
        logger.debug("sim_neurons")
        logger.debug([op for op in ops])
        logger.debug("J %s", [op.J for op in ops])

        # only ops[0] is inspected, so this assumes every op in the group
        # shares the same neuron type. the type is matched exactly (not via
        # isinstance), so subclasses of the types below use the generic
        # builder.
        neuron_type = type(ops[0].neurons)

        # if we have a custom tensorflow implementation for this neuron type,
        # then we build that. otherwise we'll just execute the neuron step
        # function externally (using `tf.py_func`), so we just need to set up
        # the inputs/outputs for that.
        if neuron_type in self.TF_NEURON_IMPL:
            # note: we do this two-step check (even though it's redundant) to
            # make sure that TF_NEURON_IMPL is kept up to date
            if neuron_type == RectifiedLinear:
                self.built_neurons = RectifiedLinearBuilder(ops, signals)
            elif neuron_type == Sigmoid:
                self.built_neurons = SigmoidBuilder(ops, signals)
            elif neuron_type == LIFRate:
                self.built_neurons = LIFRateBuilder(ops, signals)
            elif neuron_type == LIF:
                self.built_neurons = LIFBuilder(ops, signals)
            elif neuron_type == SoftLIFRate:
                self.built_neurons = SoftLIFRateBuilder(ops, signals)
            else:
                # fail here rather than with an AttributeError in build_step
                raise NotImplementedError(
                    "%s is listed in TF_NEURON_IMPL but has no TensorFlow "
                    "builder" % neuron_type)
        else:
            self.built_neurons = GenericNeuronBuilder(ops, signals)

    def build_step(self, signals):
        """Delegate to the sub-builder selected in ``__init__``."""
        self.built_neurons.build_step(signals)
class GenericNeuronBuilder(object):
    """Builds all neuron types for which there is no custom Tensorflow
    implementation.

    Each simulation step calls the neuron type's own ``step_math`` from
    Python (through ``tf.py_func``), once per op and per minibatch item.

    Notes
    -----
    These will be executed as native Python functions, requiring execution to
    move in and out of Tensorflow. This can significantly slow down the
    simulation, so any performance-critical neuron models should consider
    adding a custom Tensorflow implementation for their neuron type instead.
    """

    def __init__(self, ops, signals):
        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        # one combined signal per state slot; assumes every op has as many
        # states as ops[0] (NOTE(review): confirm callers guarantee this)
        self.state_data = [signals.combine([op.states[i] for op in ops])
                           for i in range(len(ops[0].states))]

        # output of the previous py_func call; build_step makes the next call
        # depend on it
        self.prev_result = []

        def neuron_step_math(dt, J, *states):  # pragma: no cover
            """Run each op's ``step_math`` on its slice of the combined arrays.

            Returns the combined output array followed by the state arrays
            (updated in place), matching the output list of ``tf.py_func``.
            """
            out_dtype = self.output_data.dtype
            op_outputs = []
            J_offset = 0
            state_offset = [0 for _ in states]
            for op in ops:
                # slice out the individual vectors for this op from the
                # overall arrays
                n_J = op.J.shape[0]
                op_J = J[J_offset:J_offset + n_J]
                J_offset += n_J

                op_states = []
                for k, s in enumerate(op.states):
                    n_s = s.shape[0]
                    op_states.append(
                        states[k][state_offset[k]:state_offset[k] + n_s])
                    state_offset[k] += n_s

                # call step_math separately for each minibatch item (last
                # axis). note: `op_states` are views into `states`, which
                # will be updated in-place
                mini_out = []
                for b in range(signals.minibatch_size):
                    # blank output variable
                    neuron_output = np.zeros(op.output.shape, out_dtype)
                    op.neurons.step_math(dt, op_J[..., b], neuron_output,
                                         *[s[..., b] for s in op_states])
                    mini_out.append(neuron_output)
                op_outputs.append(np.stack(mini_out, axis=-1))

            # concatenate once, rather than growing the output per op (which
            # would re-copy everything accumulated so far on every op)
            output = np.concatenate(op_outputs, axis=0)
            return (output,) + states

        self.neuron_step_math = neuron_step_math
        self.neuron_step_math.__name__ = utils.sanitize_name(
            "_".join([repr(op.neurons) for op in ops]))

    def build_step(self, signals):
        """Add a ``tf.py_func`` node running ``neuron_step_math`` and scatter
        its results into the output and state signals."""
        J = signals.gather(self.J_data)
        states = [signals.gather(x) for x in self.state_data]
        states_dtype = [x.dtype for x in self.state_data]

        # note: we need to make sure that the previous call to this function
        # has completed before the next starts, since we don't know that the
        # functions are thread safe
        with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
            ret = tf.py_func(
                self.neuron_step_math, [signals.dt, J] + states,
                [self.output_data.dtype] + states_dtype,
                name=self.neuron_step_math.__name__)
            neuron_out, state_out = ret[0], ret[1:]
        self.prev_result = [neuron_out]

        # tf.py_func outputs carry no static shape, so restore it
        neuron_out.set_shape(
            self.output_data.shape + (signals.minibatch_size,))
        signals.scatter(self.output_data, neuron_out)

        for i, s in enumerate(self.state_data):
            state_out[i].set_shape(s.shape + (signals.minibatch_size,))
            signals.scatter(s, state_out[i])
class RectifiedLinearBuilder(object):
    """Build a group of :class:`~nengo:nengo.RectifiedLinear` neuron
    operators.

    The output is simply ``max(J, 0)`` of the input current.
    """

    def __init__(self, ops, signals):
        currents = [op.J for op in ops]
        outputs = [op.output for op in ops]
        self.J_data = signals.combine(currents)
        self.output_data = signals.combine(outputs)

    def build_step(self, signals):
        rectified = tf.nn.relu(signals.gather(self.J_data))
        signals.scatter(self.output_data, rectified)
class SigmoidBuilder(object):
    """Build a group of :class:`~nengo:nengo.Sigmoid` neuron operators.

    The output is ``sigmoid(J) / tau_ref``, with ``tau_ref`` taken per neuron.
    """

    def __init__(self, ops, signals):
        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])

        # one single-element row per neuron (shape (n_neurons, 1))
        tau_ref_rows = []
        for op in ops:
            tau_ref_rows.extend(
                [op.neurons.tau_ref] for _ in range(op.J.shape[0]))
        self.tau_ref = tf.constant(tau_ref_rows, dtype=signals.dtype)

    def build_step(self, signals):
        activation = tf.nn.sigmoid(signals.gather(self.J_data))
        signals.scatter(self.output_data, activation / self.tau_ref)
class LIFRateBuilder(object):
    """Build a group of :class:`~nengo:nengo.LIFRate` neuron operators."""

    def __init__(self, ops, signals):
        # per-neuron parameters, one single-element row per neuron
        # (shape (n_neurons, 1))
        self.tau_ref = tf.constant(
            [[op.neurons.tau_ref] for op in ops
             for _ in range(op.J.shape[0])], dtype=signals.dtype)
        self.tau_rc = tf.constant(
            [[op.neurons.tau_rc] for op in ops
             for _ in range(op.J.shape[0])], dtype=signals.dtype)

        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        self.zeros = tf.zeros(self.J_data.shape + (signals.minibatch_size,),
                              signals.dtype)

    def build_step(self, signals, j=None):
        """Compute LIF firing rates and scatter them into the output signal.

        Parameters
        ----------
        signals
            The signal container that supplies ``gather`` and ``scatter``.
        j : optional
            Input current minus the firing threshold (``J - 1``). If None it
            is computed from ``J_data``; subclasses (e.g. SoftLIFRateBuilder)
            pass a transformed current here.
        """
        if j is None:
            j = signals.gather(self.J_data) - 1

        firing = j > 0

        # the rate formula is only meaningful for j > 0; for j <= 0 it yields
        # inf/NaN. tf.where discards those values in the forward pass, but
        # its gradient still multiplies them by zero (giving NaN), so
        # evaluate the formula on a copy where the non-firing entries are
        # replaced by a harmless 1
        safe_j = tf.where(firing, j, tf.ones_like(j))
        rates = 1 / (self.tau_ref + self.tau_rc * tf.log1p(1 / safe_j))

        signals.scatter(self.output_data, tf.where(firing, rates, self.zeros))
class LIFBuilder(object):
    """Build a group of :class:`~nengo:nengo.LIF` neuron operators.

    Voltage and refractory time are kept as state signals; each step emits
    ``1/dt`` for neurons whose voltage crossed 1 and 0 otherwise.
    """

    def __init__(self, ops, signals):
        def per_neuron(param):
            # repeat one neuron-type parameter once per neuron, as a column
            # of single-element rows (shape (n_neurons, 1))
            return tf.constant(
                [[param(op.neurons)] for op in ops
                 for _ in range(op.J.shape[0])], dtype=signals.dtype)

        self.tau_ref = per_neuron(lambda neurons: neurons.tau_ref)
        self.tau_rc = per_neuron(lambda neurons: neurons.tau_rc)
        self.min_voltage = per_neuron(lambda neurons: neurons.min_voltage)

        self.J_data = signals.combine([op.J for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        # states[0] is the membrane voltage, states[1] the refractory time
        self.voltage_data = signals.combine([op.states[0] for op in ops])
        self.refractory_data = signals.combine([op.states[1] for op in ops])
        self.zeros = tf.zeros(self.J_data.shape + (signals.minibatch_size,),
                              signals.dtype)

    def build_step(self, signals):
        current = signals.gather(self.J_data)
        volt = signals.gather(self.voltage_data)
        ref_time = signals.gather(self.refractory_data)

        # count down the refractory period, then find how much of this step
        # each neuron actually gets to integrate
        ref_time = ref_time - signals.dt
        delta_t = tf.clip_by_value(signals.dt - ref_time, 0, signals.dt)

        # discretized lowpass update of the membrane voltage
        volt = volt - (current - volt) * (
            tf.exp(-delta_t / self.tau_rc) - 1)

        spiked = volt > 1
        signals.scatter(self.output_data,
                        tf.cast(spiked, signals.dtype) / signals.dt)

        # spike time within the step (solving for v = 1) plus the refractory
        # period. this is computed for every neuron and non-spiking results
        # (possibly inf/NaN) are discarded by tf.where: gathering only the
        # spiking neurons with scatter/gather_nd was slower, since those ops
        # had no GPU kernel when this was written
        t_spike = (self.tau_ref + signals.dt + self.tau_rc *
                   tf.log1p((1 - volt) / (current - 1)))
        ref_time = tf.where(spiked, t_spike, ref_time)
        signals.mark_gather(self.J_data)
        signals.scatter(self.refractory_data, ref_time)

        # reset spiking neurons to 0 and floor the rest at min_voltage
        volt = tf.where(spiked, self.zeros,
                        tf.maximum(volt, self.min_voltage))
        signals.scatter(self.voltage_data, volt)
class SoftLIFRateBuilder(LIFRateBuilder):
    """Build a group of SoftLIFRate neuron operators.

    The input current is smoothed with a softplus of width ``sigma`` and then
    passed through the LIFRate rate computation of the parent class.
    """

    def __init__(self, ops, signals):
        super(SoftLIFRateBuilder, self).__init__(ops, signals)
        # per-neuron smoothing width, shape (n_neurons, 1)
        self.sigma = tf.constant(
            [[op.neurons.sigma] for op in ops
             for _ in range(op.J.shape[0])], dtype=signals.dtype)

    def build_step(self, signals):
        x = signals.gather(self.J_data) - 1
        y = x / self.sigma

        # softplus smoothing; for y >= 34, sigma * log1p(exp(y)) equals x to
        # working precision, so x is used directly. the exp argument is
        # clamped so the branch tf.where discards cannot overflow to inf
        # (inf there turns into NaN gradients through tf.where)
        smoothed = self.sigma * tf.log1p(tf.exp(tf.minimum(y, 34)))
        z = tf.where(y < 34, smoothed, x)

        super(SoftLIFRateBuilder, self).build_step(signals, j=z)