from nengo import Lowpass
from nengo.builder import Signal
from nengo.builder.learning_rules import SimBCM, SimOja, SimVoja, get_post_ens
from nengo.builder.operator import Operator, Reset, DotInc, Copy
from nengo.learning_rules import PES
from nengo.version import version_info as nengo_version
import numpy as np
import tensorflow as tf
from nengo_dl.builder import Builder, OpBuilder, NengoBuilder
[docs]@Builder.register(SimBCM)
class SimBCMBuilder(OpBuilder):
"""Build a group of :class:`~nengo:nengo.builder.learning_rules.SimBCM`
operators."""
def __init__(self, ops, signals):
super(SimBCMBuilder, self).__init__(ops, signals)
self.post_data = signals.combine([op.post_filtered for op in ops])
self.theta_data = signals.combine([op.theta for op in ops])
self.pre_data = signals.combine(
[op.pre_filtered for op in ops
for _ in range(op.post_filtered.shape[0])], load_indices=False)
self.pre_data = self.pre_data.reshape((self.post_data.shape[0],
ops[0].pre_filtered.shape[0]))
self.pre_data.load_indices(constant=signals.constant)
self.learning_rate = signals.op_constant(
ops, [op.post_filtered.shape[0] for op in ops], "learning_rate",
signals.dtype, ndims=3)
self.output_data = signals.combine([op.delta for op in ops])
[docs] def build_step(self, signals):
pre = signals.gather(self.pre_data)
post = signals.gather(self.post_data)
theta = signals.gather(self.theta_data)
post = self.learning_rate * signals.dt * post * (post - theta)
post = tf.expand_dims(post, 1)
signals.scatter(self.output_data, post * pre)
[docs]@Builder.register(SimOja)
class SimOjaBuilder(OpBuilder):
"""Build a group of :class:`~nengo:nengo.builder.learning_rules.SimOja`
operators."""
def __init__(self, ops, signals):
super(SimOjaBuilder, self).__init__(ops, signals)
self.post_data = signals.combine([op.post_filtered for op in ops])
self.pre_data = signals.combine(
[op.pre_filtered for op in ops
for _ in range(op.post_filtered.shape[0])], load_indices=False)
self.pre_data = self.pre_data.reshape((self.post_data.shape[0],
ops[0].pre_filtered.shape[0]))
self.pre_data.load_indices(constant=signals.constant)
self.weights_data = signals.combine([op.weights for op in ops])
self.output_data = signals.combine([op.delta for op in ops])
self.learning_rate = signals.op_constant(
ops, [op.post_filtered.shape[0] for op in ops], "learning_rate",
signals.dtype, ndims=3)
self.beta = signals.op_constant(
ops, [op.post_filtered.shape[0] for op in ops], "beta",
signals.dtype, ndims=3)
[docs] def build_step(self, signals):
pre = signals.gather(self.pre_data)
post = signals.gather(self.post_data)
weights = signals.gather(self.weights_data)
post = tf.expand_dims(post, 1)
alpha = self.learning_rate * signals.dt
update = alpha * post ** 2
update *= -self.beta * weights
update += alpha * post * pre
signals.scatter(self.output_data, update)
[docs]@Builder.register(SimVoja)
class SimVojaBuilder(OpBuilder):
"""Build a group of :class:`~nengo:nengo.builder.learning_rules.SimVoja`
operators."""
def __init__(self, ops, signals):
super(SimVojaBuilder, self).__init__(ops, signals)
self.post_data = signals.combine([op.post_filtered for op in ops])
self.pre_data = signals.combine(
[op.pre_decoded for op in ops
for _ in range(op.post_filtered.shape[0])], load_indices=False)
self.pre_data = self.pre_data.reshape((self.post_data.shape[0],
ops[0].pre_decoded.shape[0]))
self.pre_data.load_indices(constant=signals.constant)
self.learning_data = signals.combine(
[op.learning_signal for op in ops
for _ in range(op.post_filtered.shape[0])])
self.encoder_data = signals.combine([op.scaled_encoders for op in ops])
self.output_data = signals.combine([op.delta for op in ops])
self.scale = signals.constant(
np.concatenate([op.scale[:, None, None] for op in ops], axis=0),
dtype=signals.dtype)
self.learning_rate = signals.op_constant(
ops, [op.post_filtered.shape[0] for op in ops], "learning_rate",
signals.dtype)
[docs] def build_step(self, signals):
pre = signals.gather(self.pre_data)
post = signals.gather(self.post_data)
learning_signal = signals.gather(self.learning_data)
scaled_encoders = signals.gather(self.encoder_data)
alpha = tf.expand_dims(
self.learning_rate * signals.dt * learning_signal, 1)
post = tf.expand_dims(post, 1)
update = alpha * (self.scale * post * pre - post * scaled_encoders)
signals.scatter(self.output_data, update)
[docs]class SimPES(Operator):
r"""
Calculate connection weight change according to the PES rule.
Implements the PES learning rule of the form
.. math:: \Delta \omega_{ij} = \frac{\kappa}{n} e_j a_i
where
* :math:`\kappa` is a scalar learning rate,
* :math:`n` is the number of presynaptic neurons
* :math:`e_j` is the error for the jth output dimension, and
* :math:`a_i` is the activity of a presynaptic neuron.
Parameters
----------
pre_filtered : Signal
The presynaptic activity, :math:`a_i`.
error : Signal
The error signal, :math:`e_j`.
delta : Signal
The synaptic weight change to be applied, :math:`\Delta \omega_{ij}`.
learning_rate : float
The scalar learning rate, :math:`\kappa`.
tag : str, optional (Default: None)
A label associated with the operator, for debugging purposes.
Attributes
----------
pre_filtered : Signal
The presynaptic activity, :math:`a_i`.
error : Signal
The error signal, :math:`e_j`.
delta : Signal
The synaptic weight change to be applied, :math:`\Delta \omega_{ij}`.
learning_rate : float
The scalar learning rate, :math:`\kappa`.
tag : str, optional (Default: None)
A label associated with the operator, for debugging purposes.
Notes
-----
1. sets ``[delta]``
2. incs ``[]``
3. reads ``[pre_filtered, error]``
4. updates ``[]``
"""
def __init__(self, pre_filtered, error, delta, learning_rate, tag=None):
super(SimPES, self).__init__(tag=tag)
self.pre_filtered = pre_filtered
self.error = error
self.delta = delta
self.learning_rate = learning_rate
# TODO: change this to an update when (if) we make that change in nengo
self.sets = [delta]
self.incs = []
self.reads = [pre_filtered, error]
self.updates = []
def _descstr(self):
return 'pre=%s, error=%s -> %s' % (
self.pre_filtered, self.error, self.delta)
[docs]@NengoBuilder.register(PES)
def build_pes(model, pes, rule):
"""
Builds a `.PES` object into a model.
Parameters
----------
model : Model
The model to build into.
pes : PES
Learning rule type to build.
rule : LearningRule
The learning rule object corresponding to the neuron type.
Notes
-----
Does not modify ``model.params[]`` and can therefore be called
more than once with the same `.PES` instance.
"""
conn = rule.connection
# Create input error signal
error = Signal(np.zeros(rule.size_in), name="PES:error")
model.add_op(Reset(error))
model.sig[rule]['in'] = error # error connection will attach here
if nengo_version < (2, 7, 1):
acts = model.build(
Lowpass(pes.pre_tau), model.sig[conn.pre_obj]["out"])
else:
acts = model.build(pes.pre_synapse, model.sig[conn.pre_obj]["out"])
if not conn.is_decoded:
# multiply error by post encoders to get a per-neuron error
post = get_post_ens(conn)
encoders = model.sig[post]["encoders"]
if conn.post_obj is not conn.post:
# in order to avoid slicing encoders along an axis > 0, we pad
# `error` out to the full base dimensionality and then do the
# dotinc with the full encoder matrix
padded_error = Signal(np.zeros(encoders.shape[1]))
model.add_op(Copy(error, padded_error,
dst_slice=conn.post_slice))
else:
padded_error = error
# error = dot(encoders, error)
local_error = Signal(np.zeros(post.n_neurons), name="PES:encoded")
model.add_op(Reset(local_error))
model.add_op(DotInc(encoders, padded_error, local_error,
tag="PES:encode"))
else:
local_error = error
model.operators.append(SimPES(acts, local_error, model.sig[rule]["delta"],
pes.learning_rate))
# expose these for probes
model.sig[rule]["error"] = error
model.sig[rule]["activities"] = acts
# remove 'correction' from probeable attributes
PES.probeable = ("error", "activities", "delta")
[docs]@Builder.register(SimPES)
class SimPESBuilder(OpBuilder):
"""Build a group of :class:`.SimPES` operators."""
def __init__(self, ops, signals):
super(SimPESBuilder, self).__init__(ops, signals)
self.error_data = signals.combine(
[op.error for op in ops], load_indices=False)
self.error_data = self.error_data.reshape(
(len(ops), ops[0].error.shape[0], 1))
self.error_data.load_indices(constant=signals.constant)
self.pre_data = signals.combine(
[op.pre_filtered for op in ops], load_indices=False)
self.pre_data = self.pre_data.reshape(
(len(ops), 1, ops[0].pre_filtered.shape[0]))
self.pre_data.load_indices(constant=signals.constant)
self.alpha = signals.op_constant(
ops, [1 for _ in ops], "learning_rate", signals.dtype, ndims=4) * (
-signals.dt_val / ops[0].pre_filtered.shape[0])
self.output_data = signals.combine([op.delta for op in ops])
[docs] def build_step(self, signals):
pre_filtered = signals.gather(self.pre_data)
error = signals.gather(self.error_data)
error *= self.alpha
update = error * pre_filtered
signals.scatter(self.output_data, update)