"""
Build classes for Nengo learning rule operators.
"""
from nengo.builder import Signal
from nengo.builder.learning_rules import (
    SimBCM,
    SimOja,
    SimVoja,
    get_post_ens,
    build_or_passthrough,
)
from nengo.builder.operator import Reset, DotInc, Copy
from nengo.learning_rules import PES
import numpy as np
import tensorflow as tf

from nengo_dl.builder import Builder, OpBuilder, NengoBuilder
from nengo_dl.compat import SimPES


@Builder.register(SimBCM)
class SimBCMBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimBCM`
    operators."""

    def __init__(self, ops, signals, config):
        super(SimBCMBuilder, self).__init__(ops, signals, config)

        self.post_data = signals.combine([op.post_filtered for op in ops])
        self.theta_data = signals.combine([op.theta for op in ops])
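        # tile the pre activities so that there is one copy per post neuron;
        # after the reshape below, this lets the outer product in `build_step`
        # be computed as a broadcast elementwise multiply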
        self.pre_data = signals.combine(
            [op.pre_filtered for op in ops for _ in range(op.post_filtered.shape[0])]
        )
        self.pre_data = self.pre_data.reshape(
            (self.post_data.shape[0], ops[0].pre_filtered.shape[0])
        )
        self.learning_rate = signals.op_constant(
            ops,
            [op.post_filtered.shape[0] for op in ops],
            "learning_rate",
            signals.dtype,
            ndims=3,
        )
        self.output_data = signals.combine([op.delta for op in ops])

    def build_step(self, signals):
        pre = signals.gather(self.pre_data)
        post = signals.gather(self.post_data)
        theta = signals.gather(self.theta_data)
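        # apply the BCM update; per connection weight this computes
        #     delta[i, j] = learning_rate * dt * post[i] * (post[i] - theta[i]) * pre[j]
        # where theta is the modification threshold for each post neuron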
        post = self.learning_rate * signals.dt * post * (post - theta)
        post = tf.expand_dims(post, 1)

        signals.scatter(self.output_data, post * pre)

    @staticmethod
    def mergeable(x, y):
        # pre inputs must have the same dimensionality so that we can broadcast
        # them when computing the outer product
        return x.pre_filtered.shape[0] == y.pre_filtered.shape[0]


@Builder.register(SimOja)
class SimOjaBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimOja`
    operators."""

    def __init__(self, ops, signals, config):
        super(SimOjaBuilder, self).__init__(ops, signals, config)

        self.post_data = signals.combine([op.post_filtered for op in ops])
        self.pre_data = signals.combine(
            [op.pre_filtered for op in ops for _ in range(op.post_filtered.shape[0])]
        )
        self.pre_data = self.pre_data.reshape(
            (self.post_data.shape[0], ops[0].pre_filtered.shape[0])
        )
        self.weights_data = signals.combine([op.weights for op in ops])
        self.output_data = signals.combine([op.delta for op in ops])
        self.learning_rate = signals.op_constant(
            ops,
            [op.post_filtered.shape[0] for op in ops],
            "learning_rate",
            signals.dtype,
            ndims=3,
        )
        self.beta = signals.op_constant(
            ops,
            [op.post_filtered.shape[0] for op in ops],
            "beta",
            signals.dtype,
            ndims=3,
        )

    def build_step(self, signals):
        pre = signals.gather(self.pre_data)
        post = signals.gather(self.post_data)
        weights = signals.gather(self.weights_data)
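        # apply the Oja update:
        #     delta = alpha * post * pre - beta * alpha * post**2 * weights
        # i.e. a Hebbian term plus a forgetting term that keeps weights bounded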
        post = tf.expand_dims(post, 1)

        alpha = self.learning_rate * signals.dt

        update = alpha * post ** 2
        update *= -self.beta * weights
        update += alpha * post * pre

        signals.scatter(self.output_data, update)

    @staticmethod
    def mergeable(x, y):
        # pre inputs must have the same dimensionality so that we can broadcast
        # them when computing the outer product
        return x.pre_filtered.shape[0] == y.pre_filtered.shape[0]


@Builder.register(SimVoja)
class SimVojaBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimVoja`
    operators."""

    def __init__(self, ops, signals, config):
        super(SimVojaBuilder, self).__init__(ops, signals, config)

        self.post_data = signals.combine([op.post_filtered for op in ops])
        self.pre_data = signals.combine(
            [op.pre_decoded for op in ops for _ in range(op.post_filtered.shape[0])]
        )
        self.pre_data = self.pre_data.reshape(
            (self.post_data.shape[0], ops[0].pre_decoded.shape[0])
        )
        self.learning_data = signals.combine(
            [op.learning_signal for op in ops for _ in range(op.post_filtered.shape[0])]
        )
        self.encoder_data = signals.combine([op.scaled_encoders for op in ops])
        self.output_data = signals.combine([op.delta for op in ops])
        self.scale = signals.constant(
            np.concatenate([op.scale[:, None, None] for op in ops], axis=0),
            dtype=signals.dtype,
        )
        self.learning_rate = signals.op_constant(
            ops,
            [op.post_filtered.shape[0] for op in ops],
            "learning_rate",
            signals.dtype,
        )

    def build_step(self, signals):
        pre = signals.gather(self.pre_data)
        post = signals.gather(self.post_data)
        learning_signal = signals.gather(self.learning_data)
        scaled_encoders = signals.gather(self.encoder_data)
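        # apply the Voja update:
        #     delta = learning_rate * dt * learning_signal
        #             * (scale * post * pre - post * scaled_encoders)
        # which shifts each active post neuron's encoder towards the decoded
        # pre value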
        alpha = tf.expand_dims(self.learning_rate * signals.dt * learning_signal, 1)
        post = tf.expand_dims(post, 1)

        update = alpha * (self.scale * post * pre - post * scaled_encoders)

        signals.scatter(self.output_data, update)

    @staticmethod
    def mergeable(x, y):
        # pre inputs must have the same dimensionality so that we can broadcast
        # them when computing the outer product
        return x.pre_decoded.shape[0] == y.pre_decoded.shape[0]


@NengoBuilder.register(PES)
def build_pes(model, pes, rule):
    """
    Builds a `nengo.PES` object into a model.

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        Learning rule type to build.
    rule : LearningRule
        The learning rule object to which this learning rule type is attached.

    Notes
    -----
    Does not modify ``model.params[]`` and can therefore be called
    more than once with the same `nengo.PES` instance.
    """
    conn = rule.connection

    # Create input error signal
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error  # error connection will attach here

    acts = build_or_passthrough(model, pes.pre_synapse, model.sig[conn.pre_obj]["out"])

    if not conn.is_decoded:
        # multiply error by post encoders to get a per-neuron error
        post = get_post_ens(conn)
        encoders = model.sig[post]["encoders"]

        if conn.post_obj is not conn.post:
            # in order to avoid slicing encoders along an axis > 0, we pad
            # `error` out to the full base dimensionality and then do the
            # dotinc with the full encoder matrix
            padded_error = Signal(np.zeros(encoders.shape[1]))
            model.add_op(Copy(error, padded_error, dst_slice=conn.post_slice))
        else:
            padded_error = error

        # error = dot(encoders, error)
        local_error = Signal(np.zeros(post.n_neurons), name="PES:encoded")
        model.add_op(Reset(local_error))
        model.add_op(DotInc(encoders, padded_error, local_error, tag="PES:encode"))
    else:
        local_error = error

    model.operators.append(
        SimPES(acts, local_error, model.sig[rule]["delta"], pes.learning_rate)
    )

    # expose these for probes
    model.sig[rule]["error"] = error
    model.sig[rule]["activities"] = acts


@Builder.register(SimPES)
class SimPESBuilder(OpBuilder):
    """Build a group of `~nengo.builder.learning_rules.SimPES` operators."""

    def __init__(self, ops, signals, config):
        super(SimPESBuilder, self).__init__(ops, signals, config)

        self.error_data = signals.combine([op.error for op in ops])
        self.error_data = self.error_data.reshape((len(ops), ops[0].error.shape[0], 1))

        self.pre_data = signals.combine([op.pre_filtered for op in ops])
        self.pre_data = self.pre_data.reshape(
            (len(ops), 1, ops[0].pre_filtered.shape[0])
        )
        self.alpha = signals.op_constant(
            ops, [1 for _ in ops], "learning_rate", signals.dtype, ndims=4
        ) * (-signals.dt_val / ops[0].pre_filtered.shape[0])

        assert all(op.encoders is None for op in ops)

        self.output_data = signals.combine([op.delta for op in ops])

    def build_step(self, signals):
        pre_filtered = signals.gather(self.pre_data)
        error = signals.gather(self.error_data)
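        # apply the PES update:
        #     delta = -learning_rate * dt / n_neurons * error * pre_filtered
        # (the outer product of the error signal and the filtered activities)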
        error *= self.alpha
        update = error * pre_filtered

        signals.scatter(self.output_data, update)

    @staticmethod
    def mergeable(x, y):
        # pre inputs must have the same dimensionality so that we can broadcast
        # them when computing the outer product.
        # the error signals also have to have the same shape.
        return (
            x.pre_filtered.shape[0] == y.pre_filtered.shape[0]
            and x.error.shape[0] == y.error.shape[0]
        )