Reusing connection weights
This example shows how to reuse weights that have been learned with online learning in another network. It assumes that you’ve already gone through and understood the heteroassociative memory learning example.
First create the network as shown in the aforementioned heteroassociative memory example.
[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import nengo
[2]:
num_items = 5
d_key = 2
d_value = 4
SEED = 7
rng = np.random.RandomState(seed=SEED)
keys = nengo.dists.UniformHypersphere(surface=True).sample(num_items, d_key, rng=rng)
values = nengo.dists.UniformHypersphere(surface=False).sample(
    num_items, d_value, rng=rng
)
# Choose the intercept from the largest key-to-key similarity so that each
# neuron ends up responding to only one key
intercept = (np.dot(keys, keys.T) - np.eye(num_items)).flatten().max()
def cycle_array(x, cycle_period, cycle_dt=0.001):
    """Cycles through the elements"""
    i_every = int(round(cycle_period / cycle_dt))
    if i_every != cycle_period / cycle_dt:
        raise ValueError(f"dt ({cycle_dt}) does not divide period ({cycle_period})")

    def f(t):
        i = int(round((t - cycle_dt) / cycle_dt))  # t starts at dt
        idx = (i // i_every) % len(x)
        return x[idx]

    return f
# Model constants
n_neurons = 200
dt = 0.001
period = 0.3
T = period * num_items
sample_every = 0.01
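# Quick illustrative check of the cycling helper (not part of the original
# example): the first key is presented for the first `period` seconds of
# simulation time, the second key for the next `period` seconds, and so on.
_f = cycle_array(keys, period, dt)
assert np.array_equal(_f(dt), keys[0])
assert np.array_equal(_f(period + dt), keys[1])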
with nengo.Network() as train_model:
    # Create the inputs/outputs
    stim_keys = nengo.Node(cycle_array(keys, period, dt))
    stim_values = nengo.Node(cycle_array(values, period, dt))
    # Turn learning permanently on
    learning = nengo.Node([0])
    recall = nengo.Node(size_in=d_value)

    # Create the memory with a seed, so we can create the same ensemble again
    # in the new network
    memory = nengo.Ensemble(
        n_neurons, d_key, intercepts=[intercept] * n_neurons, seed=SEED
    )

    # Learn the encoders/keys
    voja = nengo.Voja(post_synapse=None, learning_rate=5e-2)
    conn_in = nengo.Connection(stim_keys, memory, synapse=None, learning_rule_type=voja)
    nengo.Connection(learning, conn_in.learning_rule, synapse=None)

    # Learn the decoders/values, initialized to a null function
    conn_out = nengo.Connection(
        memory,
        recall,
        learning_rule_type=nengo.PES(1e-3),
        function=lambda x: np.zeros(d_value),
    )

    # Create the error population
    error = nengo.Ensemble(n_neurons, d_value)
    nengo.Connection(
        learning, error.neurons, transform=[[10.0]] * n_neurons, synapse=None
    )

    # Calculate the error and use it to drive the PES rule
    nengo.Connection(stim_values, error, transform=-1, synapse=None)
    nengo.Connection(recall, error, synapse=None)
    nengo.Connection(error, conn_out.learning_rule)
Instead of probing the usual outputs of the network, we’re going to probe the weights so we can transfer them to our new network. We could still probe the inputs and outputs to verify that the network is functioning, but for now we’re just going to assume it works.
[3]:
with train_model:
    # Setup probes to save the weights
    p_dec = nengo.Probe(conn_out, "weights", sample_every=sample_every)
    p_enc = nengo.Probe(memory, "scaled_encoders", sample_every=sample_every)
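If we did want to verify that learning is progressing, we could also probe the inputs and the recalled output of the training network. A minimal sketch of such optional probes (the names are illustrative and unused in the rest of this example):

with train_model:
    # Optional probes for monitoring training
    p_train_values = nengo.Probe(stim_values, synapse=0.005)
    p_train_recall = nengo.Probe(recall, synapse=0.005)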
[4]:
# run the model and retrieve the encoders and decoders
with nengo.Simulator(train_model, dt=dt) as train_sim:
    train_sim.run(T)

enc = train_sim.data[p_enc][-1]
dec = train_sim.data[p_dec][-1]
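If the new network is built in a separate script or session, the recovered arrays can also be written to disk and loaded back later. A minimal sketch using NumPy (the file name is illustrative):

# Save the learned encoders and decoders for reuse elsewhere
np.savez("learned_weights.npz", enc=enc, dec=dec)

# ...later, in another script...
weights = np.load("learned_weights.npz")
enc, dec = weights["enc"], weights["dec"]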
We’ll now insert the encoders and decoders we gathered into our new network and verify that it works the same as the old one. One important thing to note is that we set the seed parameter of the memory ensemble the same way as in the previous network, to make sure we’re dealing with the same neurons.
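Before building the new network, it can also help to sanity-check the shapes of the recovered arrays: enc holds one scaled encoder per neuron, while dec maps neuron activities to the value space. A minimal check (optional, assuming the cells above have been run):

assert enc.shape == (n_neurons, d_key)
assert dec.shape == (d_value, n_neurons)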
[5]:
with nengo.Network() as test_model:
    # Create the inputs/outputs
    stim_keys = nengo.Node(cycle_array(keys, period, dt))
    stim_values = nengo.Node(cycle_array(values, period, dt))
    # Turn learning off to show that our network still works
    learning = nengo.Node([-1])
    recall = nengo.Node(size_in=d_value)

    # Create the memory with the new encoders
    memory = nengo.Ensemble(
        n_neurons,
        d_key,
        intercepts=[intercept] * n_neurons,
        encoders=enc,
        n_eval_points=0,
        seed=SEED,
    )
    nengo.Connection(stim_keys, memory, synapse=None)

    # Create the connection out with the new decoders
    conn_out = nengo.Connection(memory.neurons, recall, transform=dec)

    # Setup probes
    p_val = nengo.Probe(stim_values, synapse=0.005)
    p_recall = nengo.Probe(recall, synapse=0.005)
[6]:
# run the network and plot the results for verification
with nengo.Simulator(test_model, dt=dt) as test_sim:
    test_sim.run(T)
[7]:
plt.plot(test_sim.trange(), test_sim.data[p_val])
plt.title("Expected")
plt.xlabel("Time (s)")
plt.figure()
plt.plot(test_sim.trange(), test_sim.data[p_recall])
plt.title("Recalled")
plt.xlabel("Time (s)")
plt.show()
The values output by our memory match our expected values. Our connection weight transfer worked!
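As an optional quantitative check (a sketch, assuming the simulations above have been run), we can also compare the two filtered signals directly:

# Rough root-mean-square error between the recalled and expected values
rmse = np.sqrt(np.mean((test_sim.data[p_recall] - test_sim.data[p_val]) ** 2))
print(f"RMSE between recalled and expected values: {rmse:.3f}")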