Learning to square the input
This demo shows how to construct a network containing an ensemble whose outgoing connection learns to decode the square of the value the ensemble represents.
[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import nengo
Step 1: Create the Model
This network consists of an ensemble A, which represents the input; an ensemble A_squared, which learns to represent the square; and an ensemble error, which represents the error between A_squared and the actual square.
[2]:
model = nengo.Network()
with model:
    # Create the ensemble to represent the input, the input squared (learned),
    # and the error
    A = nengo.Ensemble(100, dimensions=1)
    A_squared = nengo.Ensemble(100, dimensions=1)
    error = nengo.Ensemble(100, dimensions=1)
    # Connect A and A_squared with a communication channel
    conn = nengo.Connection(A, A_squared)
    # Apply the PES learning rule to conn
    conn.learning_rule_type = nengo.PES(learning_rate=3e-4)
    # Provide an error signal to the learning rule
    nengo.Connection(error, conn.learning_rule)
    # Compute the error signal (error = actual - target)
    nengo.Connection(A_squared, error)
    # Subtract the target (this would normally come from some external system)
    nengo.Connection(A, error, function=lambda x: x ** 2, transform=-1)
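The idea behind error-driven decoder learning can be illustrated outside of Nengo. The following is a minimal rate-based sketch in plain NumPy, not Nengo's PES implementation: the rectified-linear tuning curves, the gain and bias ranges, and the learning rate of 0.1 are all assumptions chosen for illustration. It starts from decoders that approximate a communication channel and nudges them in proportion to the error (actual minus target) and each neuron's activity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons = 100

# Hypothetical rectified-linear tuning curves (not Nengo's LIF neurons)
encoders = rng.choice([-1.0, 1.0], size=n_neurons)
gains = rng.uniform(0.5, 2.0, size=n_neurons)
biases = rng.uniform(-1.0, 1.0, size=n_neurons)


def activities(x):
    """Firing rates of all neurons for a scalar input x."""
    return np.maximum(0.0, gains * encoders * x + biases)


# Start from decoders that approximate the identity (a communication channel)
xs = np.linspace(-1, 1, 200)
acts = np.array([activities(x) for x in xs])  # shape (200, n_neurons)
decoders = np.linalg.lstsq(acts, xs, rcond=None)[0]


def rmse_vs_square(d):
    """RMS difference between the decoded value and x**2 over the test grid."""
    return np.sqrt(np.mean((acts @ d - xs**2) ** 2))


rmse_before = rmse_vs_square(decoders)

learning_rate = 0.1  # illustrative; not comparable to the 3e-4 used in the model
for x in rng.uniform(-1, 1, size=5000):
    a = activities(x)
    err = a @ decoders - x**2  # actual - target, as in the model
    decoders -= learning_rate / n_neurons * err * a  # error-driven update

rmse_after = rmse_vs_square(decoders)
print(f"RMSE before: {rmse_before:.3f}, after: {rmse_after:.3f}")
```

In the Nengo model the same roles are played by the error ensemble (which supplies the error) and the activity of A (which the decoders on conn weight); the spiking rule works on filtered spikes, so its constants differ from this sketch.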
Step 2: Provide Input to the Model
A single input signal (a staircase function that repeats every 5 seconds) drives the neural activity in ensemble A. An additional node inhibits the error population after 15 seconds, so that the final 5 seconds test what has been learned.
[3]:
with model:
    # Create an input node that steps from -1 up to 2/3 in increments of 1/3,
    # repeating every 5 seconds
    input_node = nengo.Node(output=lambda t: int(6 * t / 5) / 3.0 % 2 - 1)
    # Connect the input node to ensemble A
    nengo.Connection(input_node, A)
    # Shut off learning by inhibiting the error population
    stop_learning = nengo.Node(output=lambda t: t >= 15)
    nengo.Connection(
        stop_learning, error.neurons, transform=-20 * np.ones((error.n_neurons, 1))
    )
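To see what the input expression produces, it can be evaluated on its own. The sketch below copies the lambda into a named function (the name staircase and the choice of sample times are assumptions for illustration) and samples the middle of each step in the first period.

```python
import numpy as np


def staircase(t):
    """Same expression as the input node's output function."""
    return int(6 * t / 5) / 3.0 % 2 - 1


# Sample the midpoint of each of the six steps in the first 5-second period
midpoints = (np.arange(6) + 0.5) * 5 / 6
levels = np.array([staircase(t) for t in midpoints])
print(np.round(levels, 3))

# The same levels recur one period (5 seconds) later
next_period = np.array([staircase(t + 5) for t in midpoints])
print(np.allclose(levels, next_period))
```

Each level lasts 5/6 of a second, so a 20-second run covers four full periods: three with learning enabled and one with the error population inhibited.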
Step 3: Probe the Output
Let’s collect output data from each ensemble, from the input node, and from the node that stops learning.
[4]:
with model:
    input_node_probe = nengo.Probe(input_node)
    A_probe = nengo.Probe(A, synapse=0.01)
    A_squared_probe = nengo.Probe(A_squared, synapse=0.01)
    error_probe = nengo.Probe(error, synapse=0.01)
    learn_probe = nengo.Probe(stop_learning, synapse=None)
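The synapse=0.01 argument on the ensemble probes smooths the decoded spikes with a lowpass filter whose time constant is 10 ms. As a rough illustration of what such a filter does, here is a minimal discrete first-order lowpass in NumPy. It is a sketch under the assumption of a 1 ms timestep, not Nengo's own filter code, and the function name lowpass is hypothetical.

```python
import numpy as np


def lowpass(x, tau, dt=0.001):
    """Discrete first-order lowpass: y[k] = a * y[k-1] + (1 - a) * x[k]."""
    a = np.exp(-dt / tau)
    y = np.zeros(len(x))
    state = 0.0
    for k, xk in enumerate(x):
        state = a * state + (1 - a) * xk
        y[k] = state
    return y


# Response to a unit step: rises toward 1 with time constant tau
step = np.ones(100)
filtered = lowpass(step, tau=0.01)
print(filtered[9], filtered[-1])
```

After one time constant the filtered step has covered about 63% of the distance to its final value, which is why the probed traces lag sharp changes in the input slightly.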
Step 4: Run the Model
[5]:
# Create the simulator
with nengo.Simulator(model) as sim:
    sim.run(20)
[6]:
# Plot the input signal
plt.figure(figsize=(9, 9))
plt.subplot(3, 1, 1)
plt.plot(
    sim.trange(), sim.data[input_node_probe], label="Input", color="k", linewidth=2.0
)
plt.plot(
    sim.trange(),
    sim.data[learn_probe],
    label="Stop learning?",
    color="r",
    linewidth=2.0,
)
plt.legend(loc="lower right")
plt.ylim(-1.2, 1.2)
plt.subplot(3, 1, 2)
plt.plot(
    sim.trange(), sim.data[input_node_probe] ** 2, label="Squared Input", linewidth=2.0
)
plt.plot(sim.trange(), sim.data[A_squared_probe], label="Decoded Ensemble $A^2$")
plt.legend(loc="lower right")
plt.ylim(-1.2, 1.2)
plt.subplot(3, 1, 3)
plt.plot(
    sim.trange(),
    sim.data[A_squared_probe] - sim.data[input_node_probe] ** 2,
    label="Error",
)
plt.legend(loc="lower right")
plt.tight_layout()
We see that during the first three periods (0 to 15 seconds), the decoders quickly adjust to drive the error toward zero. When learning is turned off for the fourth period (15 to 20 seconds), the error stays close to zero, demonstrating that the learning has persisted in the connection.
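The visual comparison between periods can also be made numerically. The helper below is a hypothetical sketch (the name windowed_rmse is not part of Nengo) that computes the root-mean-square error inside a time window. It is demonstrated on synthetic stand-in data, not on simulation output.

```python
import numpy as np


def windowed_rmse(t, decoded, target, start, stop):
    """RMS error between decoded and target for start <= t < stop."""
    t = np.asarray(t).ravel()
    decoded = np.asarray(decoded).ravel()
    target = np.asarray(target).ravel()
    mask = (t >= start) & (t < stop)
    return float(np.sqrt(np.mean((decoded[mask] - target[mask]) ** 2)))


# Synthetic stand-in data (not simulation output): an error that decays over time
t = np.arange(1, 20001) * 0.001
target = np.sin(t) ** 2
decoded = target + np.exp(-t)

early = windowed_rmse(t, decoded, target, 0, 5)
late = windowed_rmse(t, decoded, target, 15, 20)
print(early, late)
```

With the simulation above, the same helper could be called as windowed_rmse(sim.trange(), sim.data[A_squared_probe], sim.data[input_node_probe] ** 2, 15, 20) to measure the error after learning has been switched off, and with the window 0 to 5 to measure it during the first period.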