Source code for nengo_extras.learning_rules

import warnings

import numpy as np

from nengo.builder import Builder, Signal
from nengo.builder.connection import get_eval_points, solve_for_decoders
from nengo.builder.operator import (
    DotInc, ElementwiseInc, Operator, Reset, SimPyFunc)
from nengo.exceptions import ValidationError
from nengo.learning_rules import LearningRuleType
from nengo.params import EnumParam, FunctionParam, NumberParam
from nengo.synapses import Lowpass


[docs]class AML(LearningRuleType): r"""Association matrix learning rule (AML). Enables one-shot learning without catastrophic forgetting of outer product association matrices. The cue is provided by the pre-synaptic ensemble. The error signal is split up: ``error[0]`` provides a scaling factor to the learning rate. ``error[1]`` provides a decay rate (i.e., weights are multiplied with this value in every time step), ``error[2:]`` provides the target vector. The update is given by:: decoders[...] *= error[1] # decay decoders[...] += alpha * error[0] * error[2:, None] * np.dot( pre, base_decoders.T) where *alpha* is the learning rate adjusted for *dt* and *base_decoders* is the decoder matrix for decoding the identity from the pre-ensemble. Parameters ---------- d : int Dimensionality of input and output vectors (error signal will be *d+2*). learning_rate : float, optional Learning rate (increase of dot product similarity per second). """ error_type = 'decoded' modifies = 'decoders' def __init__(self, d, learning_rate=1.): super(AML, self).__init__(learning_rate, size_in=d + 2)
class SimAML(Operator): def __init__(self, learning_rate, base_decoders, pre, error, decoders, tag=None): super(SimAML, self).__init__(tag=tag) self.learning_rate = learning_rate self.base_decoders = base_decoders self.sets = [] self.incs = [] self.reads = [pre, error] self.updates = [decoders] def make_step(self, signals, dt, rng): base_decoders = self.base_decoders pre = signals[self.pre] error = signals[self.error] decoders = signals[self.decoders] alpha = self.learning_rate * dt def step_assoc_learning(): scale = error[0] decay = error[1] target = error[2:] decoders[...] *= decay decoders[...] += alpha * scale * target[:, None] * np.dot( pre, base_decoders.T) return step_assoc_learning @property def pre(self): return self.reads[0] @property def error(self): return self.reads[1] @property def decoders(self): return self.updates[0] @Builder.register(AML) def build_aml(model, aml, rule): conn = rule.connection rng = np.random.RandomState(model.seeds[conn]) error = Signal(np.zeros(rule.size_in), name="aml:error") model.add_op(Reset(error)) model.sig[rule]['in'] = error pre = model.sig[conn.pre_obj]['in'] decoders = model.sig[conn]['weights'] encoders = model.params[conn.pre_obj].encoders gain = model.params[conn.pre_obj].gain bias = model.params[conn.pre_obj].bias eval_points = get_eval_points(model, conn, rng) targets = eval_points x = np.dot(eval_points, encoders.T) wrapped_solver = (model.decoder_cache.wrap_solver(solve_for_decoders) if model.seeded[conn] else solve_for_decoders) base_decoders, _ = wrapped_solver(conn, gain, bias, x, targets, rng=rng) model.add_op(SimAML( aml.learning_rate, base_decoders, pre, error, decoders)) class DeltaRuleFunctionParam(FunctionParam): function_test_size = 8 # arbitrary size to test function def function_args(self, instance, function): return (np.zeros(self.function_test_size),) def coerce(self, instance, function): function_info = super(DeltaRuleFunctionParam, self).coerce( instance, function) function, size = function_info if function is not None and size != self.function_test_size: raise ValidationError( "Function '%s' input and output sizes must be equal" % function, attr=self.name, obj=instance) return function_info
[docs]class DeltaRule(LearningRuleType): r"""Implementation of the Delta rule. By default, this implementation pretends the neurons are linear, and thus does not require the derivative of the postsynaptic neuron activation function. The derivative function, or a surrogate function, for the postsynaptic neurons can be provided in ``post_fn``. The update is given by: \delta W_ij = \eta a_j e_i f(u_i) where ``e_i`` is the input error in the postsynaptic neuron space, ``a_j`` is the output activity for presynaptic neuron j, ``u_i`` is the input for postsynaptic neuron i, and ``f`` is a provided function. Parameters ---------- learning_rate : float A scalar indicating the rate at which weights will be adjusted. pre_tau : float Filter constant on the presynaptic output ``a_j``. post_fn : callable Function ``f`` to apply to the postsynaptic inputs ``u_i``. The default of ``None`` means the ``f(u_i)`` term is omitted. post_tau : float Filter constant on the postsynaptic input ``u_i``. This defaults to ``None`` because these should typically be filtered by the connection. """ modifies = 'weights' probeable = ('delta', 'in', 'error', 'correction', 'pre', 'post') pre_tau = NumberParam('pre_tau', low=0, low_open=True) post_tau = NumberParam('post_tau', low=0, low_open=True, optional=True) post_fn = DeltaRuleFunctionParam('post_fn', optional=True) post_target = EnumParam('post_target', values=('in', 'out')) def __init__(self, learning_rate=1e-4, pre_tau=0.005, post_fn=None, post_tau=None, post_target='in'): if learning_rate >= 1.0: warnings.warn("This learning rate is very high, and can result " "in floating point errors from too much current.") self.pre_tau = pre_tau self.post_tau = post_tau self.post_fn = post_fn self.post_target = post_target super(DeltaRule, self).__init__(learning_rate, size_in='post') @property def _argreprs(self): args = [] if self.learning_rate != 1e-4: args.append("learning_rate=%g" % self.learning_rate) if self.pre_tau != 0.005: args.append("pre_tau=%f" % self.pre_tau) if self.post_fn is not None: args.append("post_fn=%s" % self.post_fn.function) if self.post_tau is not None: args.append("post_tau=%f" % self.post_tau) if self.post_target != 'in': args.append("post_target=%s" % self.post_target) return args
@Builder.register(DeltaRule) def build_delta_rule(model, delta_rule, rule): conn = rule.connection # Create input error signal error = Signal(np.zeros(rule.size_in), name="DeltaRule:error") model.add_op(Reset(error)) model.sig[rule]['in'] = error # error connection will attach here # Multiply by post_fn output if necessary post_fn = delta_rule.post_fn.function post_tau = delta_rule.post_tau post_target = delta_rule.post_target if post_fn is not None: post_sig = model.sig[conn.post_obj][post_target] post_synapse = Lowpass(post_tau) if post_tau is not None else None post_input = (post_sig if post_synapse is None else model.build(post_synapse, post_sig)) post = Signal(np.zeros(post_input.shape), name="DeltaRule:post") model.add_op(SimPyFunc(post, post_fn, t=None, x=post_input, tag="DeltaRule:post_fn")) model.sig[rule]['post'] = post error0 = error error = Signal(np.zeros(rule.size_in), name="DeltaRule:post_error") model.add_op(Reset(error)) model.add_op(ElementwiseInc(error0, post, error)) # Compute: correction = -learning_rate * dt * error correction = Signal(np.zeros(error.shape), name="DeltaRule:correction") model.add_op(Reset(correction)) lr_sig = Signal(-delta_rule.learning_rate * model.dt, name="DeltaRule:learning_rate") model.add_op(DotInc(lr_sig, error, correction, tag="DeltaRule:correct")) # delta_ij = correction_i * pre_j pre_synapse = Lowpass(delta_rule.pre_tau) pre = model.build(pre_synapse, model.sig[conn.pre_obj]['out']) model.add_op(Reset(model.sig[rule]['delta'])) model.add_op(ElementwiseInc( correction.column(), pre.row(), model.sig[rule]['delta'], tag="DeltaRule:Inc Delta")) # expose these for probes model.sig[rule]['error'] = error model.sig[rule]['correction'] = correction model.sig[rule]['pre'] = pre