Source code for nengo_spa.networks.selection

"""Selection networks that pick one or more options among multiple choices."""
import nengo
import numpy as np


class IA(nengo.Network):
    """Independent accumulator (IA) winner-take-all (WTA) network.

    This is a two-layered network. The first layer consists of independent
    accumulators (integrators), whereas the second layer does a thresholding.
    Once the threshold is exceeded a feedback connection will stabilize the
    current choice and inhibit all other choices. To switch the selection, it
    is necessary to provide a transient input to *input_reset* to reset the
    accumulator states.

    This network is suited especially for accumulating evidence under noisy
    conditions and keeping a stable choice selection until the processing of
    the choice has been finished.

    Further details are to be found in [gosmann2017]_.

    Parameters
    ----------
    n_neurons : int
        Number of neurons for each choice.
    n_ensembles : int
        Number of choices.
    accum_threshold : float, optional
        Accumulation threshold that needs to be reached to produce an output.
    accum_neuron_ratio : float, optional
        Portion of *n_neurons* that will be used for a layer 1 accumulator
        ensemble. The remaining neurons will be used for a layer 2
        thresholding ensemble.
    accum_timescale : float, optional
        Evidence accumulation timescale.
    feedback_timescale : float, optional
        Timescale for the feedback connection from the thresholding layer to
        the accumulation layer.
    accum_synapse : float or nengo.Lowpass, optional
        The synapse for connections to the accumulator ensembles. Its time
        constant also scales the feedback and reset transforms, so it must be
        given either as a float time constant or as a `nengo.Lowpass`.
    ff_synapse : Synapse or float, optional
        Synapse for feed-forward connections.
    intercept_width : float, optional
        The `nengo.presets.ThresholdingEnsembles` *intercept_width* parameter.
    radius : float, optional
        The representational radius of the ensembles.
    **kwargs : dict
        Keyword arguments passed on to `nengo.Network`.

    Attributes
    ----------
    input : nengo.Node
        The inputs to the network.
    input_reset : nengo.Node
        Input to reset the accumulators.
    output : nengo.Node
        The outputs of the network.
    accumulators : Thresholding
        The layer 1 accumulators.
    thresholding : Thresholding
        The layer 2 thresholding ensembles.

    References
    ----------
    .. [gosmann2017] Jan Gosmann, Aaron R. Voelker, and Chris Eliasmith.
       "A spiking independent accumulator model for winner-take-all
       computation." In Proceedings of the 39th Annual Conference of the
       Cognitive Science Society. London, UK, 2017. Cognitive Science Society.
    """

    def __init__(
        self,
        n_neurons,
        n_ensembles,
        accum_threshold=0.8,
        accum_neuron_ratio=0.7,
        accum_timescale=0.2,
        feedback_timescale=0.005,
        accum_synapse=0.1,
        ff_synapse=0.005,
        intercept_width=0.15,
        radius=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        n_accum_neurons = int(accum_neuron_ratio * n_neurons)
        n_thresholding_neurons = n_neurons - n_accum_neurons

        # The feedback and reset transforms below are scaled by the
        # accumulator synapse's time constant. A nengo.Lowpass object cannot
        # be multiplied with / divide an array, so use its ``tau``; a plain
        # float time constant is used unchanged (original behaviour).
        if isinstance(accum_synapse, nengo.Lowpass):
            accum_tau = accum_synapse.tau
        else:
            accum_tau = accum_synapse

        bar_beta = 1.0 + radius * feedback_timescale / accum_timescale
        # Diagonal entries are +1; off-diagonal entries are
        # -bar_beta / feedback_timescale, i.e. each thresholded choice
        # inhibits the accumulators of all other choices.
        feedback_tr = (
            np.eye(n_ensembles)
            - bar_beta * (1.0 - np.eye(n_ensembles)) / feedback_timescale
        )

        with self:
            self.accumulators = Thresholding(
                n_accum_neurons,
                n_ensembles,
                threshold=0.0,
                intercept_width=intercept_width,
                radius=radius,
            )
            self.thresholding = Thresholding(
                n_thresholding_neurons,
                n_ensembles,
                threshold=accum_threshold,
                intercept_width=intercept_width,
                radius=radius,
                function=lambda x: x > accum_threshold,
            )

            # Recurrent connection that makes layer 1 accumulate.
            nengo.Connection(
                self.accumulators.output, self.accumulators.input, synapse=accum_synapse
            )
            nengo.Connection(
                self.accumulators.output, self.thresholding.input, synapse=ff_synapse
            )
            nengo.Connection(
                self.thresholding.output,
                self.accumulators.input,
                synapse=accum_synapse,
                transform=accum_tau * feedback_tr,
            )

            # Unfiltered, strongly negative drive (gain radius / accum_tau)
            # onto every accumulator while the reset input is active.
            self.input_reset = nengo.Node(size_in=1)
            nengo.Connection(
                self.input_reset,
                self.accumulators.input,
                synapse=None,
                transform=-radius * np.ones((n_ensembles, 1)) / accum_tau,
            )

            self.input = nengo.Node(size_in=n_ensembles)
            nengo.Connection(
                self.input,
                self.accumulators.input,
                synapse=None,
                transform=1.0 / accum_timescale,
            )

            self.output = self.thresholding.output
class Thresholding(nengo.Network):
    """Array of thresholding ensembles.

    Each input below *threshold* yields an output of 0. Each input above it
    yields an output of the same value, or *function* of that value when a
    function is supplied.

    Parameters
    ----------
    n_neurons : int
        Neurons per ensemble.
    n_ensembles : int
        Number of ensembles in the array.
    threshold : float
        Value below which inputs are suppressed.
    intercept_width : float, optional
        Forwarded as *intercept_width* to
        `nengo.presets.ThresholdingEnsembles`.
    function : function, optional
        Function applied to the thresholded values, after the threshold shift
        has been added back. The identity is used when omitted.
    radius : float, optional
        Representational radius of the ensembles.
    **kwargs : dict
        Forwarded to `nengo.Network`.

    Attributes
    ----------
    input : nengo.Node
        The inputs to the network.
    output : nengo.Node
        The outputs of the network.
    thresholded : nengo.Node
        The raw thresholded value (before applying *function* or correcting
        for the shift produced by the thresholding).
    thresholding : nengo.networks.EnsembleArray
        The underlying ensemble array.
    bias : nengo.Node
        Constant node whose output shifts the inputs down by *threshold*.
    """

    def __init__(
        self,
        n_neurons,
        n_ensembles,
        threshold,
        intercept_width=0.15,
        function=None,
        radius=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        # The ensembles threshold at 0. The bias node therefore shifts the
        # input down by ``threshold``, and the decoded output adds it back
        # before calling the user's function.
        inner = (lambda x: x) if function is None else function
        shifted = lambda x, inner=inner: inner(x + threshold)
        bias_transform = -threshold * np.ones((n_ensembles, 1))

        with self:
            preset = nengo.presets.ThresholdingEnsembles(
                0.0, intercept_width, radius=radius
            )
            with preset:
                self.thresholding = nengo.networks.EnsembleArray(
                    n_neurons, n_ensembles
                )

            self.bias = nengo.Node(1.0)
            nengo.Connection(
                self.bias, self.thresholding.input, transform=bias_transform
            )

            self.input = self.thresholding.input
            self.thresholded = self.thresholding.output
            self.output = self.thresholding.add_output("function", shifted)
class WTA(Thresholding):
    """Winner-take-all (WTA) network with lateral inhibition.

    This is a `Thresholding` array in which every ensemble's raw thresholded
    value is fed back, negatively, onto the inputs of all the other ensembles.

    Parameters
    ----------
    n_neurons : int
        Number of neurons for each ensemble.
    n_ensembles : int
        Number of ensembles.
    inhibit_scale : float, optional
        Scaling of the lateral inhibition.
    inhibit_synapse : Synapse or float, optional
        Synapse on the recurrent connection for lateral inhibition.
    **kwargs : dict
        Keyword arguments passed on to `Thresholding`. `Thresholding` has no
        default for *threshold*, so it must be supplied here as a keyword.

    Attributes
    ----------
    input : nengo.Node
        The inputs to the network.
    output : nengo.Node
        The outputs of the network.
    thresholded : nengo.Node
        The raw thresholded value (before applying *function* or correcting
        for the shift produced by the thresholding).
    """

    def __init__(
        self,
        n_neurons,
        n_ensembles,
        inhibit_scale=1.0,
        inhibit_synapse=0.005,
        **kwargs
    ):
        super().__init__(n_neurons, n_ensembles, **kwargs)

        # Zero on the diagonal (no self-inhibition), -inhibit_scale elsewhere.
        lateral_inhibition = inhibit_scale * (np.eye(n_ensembles) - 1.0)

        with self:
            nengo.Connection(
                self.thresholded,
                self.input,
                synapse=inhibit_synapse,
                transform=lateral_inhibition,
            )