PES learning
In this example, we will use the PES learning rule to learn a communication channel.
[1]:
import matplotlib.pyplot as plt
%matplotlib inline
import nengo
from nengo.processes import WhiteSignal
import nengo_loihi
nengo_loihi.set_defaults()
Creating the network in Nengo
When creating a nengo.Connection, you can specify a learning_rule_type. When using the nengo.PES learning rule type, the connection is modified such that it can accept input in its learning_rule attribute. That input is interpreted as an error signal that the PES rule attempts to minimize over time by adjusting decoders or connection weights.
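To make the idea of "adjusting decoders to minimize an error signal" concrete, here is a minimal pure-Python sketch of a PES-style update. It is not Nengo's or Loihi's implementation: the rectified-linear `rates` function, the toy gains and biases, the sine-wave target, and the scaling of the learning rate by `dt` and the neuron count are all assumptions made for illustration. The helper names `rates` and `pes_step` are hypothetical.

```python
import math
import random


def rates(x, encoders, gains, biases):
    """Rectified-linear firing rates for a scalar input x (a simplifying assumption)."""
    return [max(0.0, g * e * x + b) for e, g, b in zip(encoders, gains, biases)]


def pes_step(decoders, activities, error, learning_rate, dt):
    """Apply one PES-style update: shift each decoder against the error,
    in proportion to that neuron's activity."""
    scale = learning_rate * dt / len(decoders)
    return [d - scale * error * a for d, a in zip(decoders, activities)]


# Hypothetical toy population; these values are illustrative only.
rng = random.Random(0)
n_neurons = 50
encoders = [rng.choice([-1.0, 1.0]) for _ in range(n_neurons)]
gains = [rng.uniform(150.0, 300.0) for _ in range(n_neurons)]
biases = [rng.uniform(-120.0, 120.0) for _ in range(n_neurons)]
decoders = [0.0] * n_neurons  # start by decoding 0, like function=lambda x: [0]

dt = 0.001
abs_errors = []
for step in range(10000):  # 10 simulated seconds
    x = math.sin(2 * math.pi * step * dt)  # target value (stands in for pre)
    a = rates(x, encoders, gains, biases)
    y = sum(d * ai for d, ai in zip(decoders, a))  # decoded output (stands in for post)
    error = y - x  # error = actual - target
    decoders = pes_step(decoders, a, error, learning_rate=2e-4, dt=dt)
    abs_errors.append(abs(error))

early_error = sum(abs_errors[:1000]) / 1000
late_error = sum(abs_errors[-1000:]) / 1000
print(f"mean |error|, first second: {early_error:.3f}")
print(f"mean |error|, last second:  {late_error:.3f}")
```

Because the update always moves the decoded output against the sign of the error, the mean absolute error in the last second should be smaller than in the first. The Nengo model in this example uses the same sign convention: the error is computed as actual minus target.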
[2]:
with nengo.Network(label="PES learning") as model:
    # Randomly varying input signal
    stim = nengo.Node(WhiteSignal(60, high=5), size_out=1)

    # Connect pre to the input signal
    pre = nengo.Ensemble(100, dimensions=1)
    nengo.Connection(stim, pre)
    post = nengo.Ensemble(100, dimensions=1)

    # When connecting pre to post,
    # create the connection such that initially it will
    # always output 0. Usually this results in connection
    # weights that are also all 0.
    conn = nengo.Connection(pre, post,
                            function=lambda x: [0],
                            learning_rule_type=nengo.PES(learning_rate=2e-4))

    # Calculate the error signal with another ensemble
    error = nengo.Ensemble(100, dimensions=1)

    # Error = actual - target = post - pre
    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)

    # Connect the error into the learning rule
    nengo.Connection(error, conn.learning_rule)

    stim_p = nengo.Probe(stim)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)
    error_p = nengo.Probe(error, synapse=0.01)
Running the network in Nengo
We first run the model with the standard Nengo simulator to see the desired model output.
[3]:
with nengo.Simulator(model) as sim:
    sim.run(10)
t = sim.trange()
[4]:
def plot_decoded(t, data):
    plt.figure(figsize=(12, 12))
    plt.subplot(2, 1, 1)
    plt.plot(t, data[stim_p].T[0], label='Input')
    plt.plot(t, data[pre_p].T[0], label='pre')
    plt.plot(t, data[post_p].T[0], label='post')
    plt.ylabel("Decoded output")
    plt.legend(loc='best')
    plt.subplot(2, 1, 2)
    plt.plot(t, data[error_p])
    plt.ylim(-1, 1)
    plt.legend(("Error signal",), loc='best')

plot_decoded(t, sim.data)
While post initially only represents 0, over time it comes to more closely track the value represented in pre. The error signal also decreases gradually over time as the decoded values in pre and post get closer and closer.
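The decrease in the error signal can also be checked numerically rather than by eye. The sketch below defines a hypothetical helper, `windowed_rms`, that reduces a sampled signal to one RMS value per non-overlapping window. To keep it runnable on its own, it is applied here to a synthetic decaying oscillation that merely stands in for probed error data; the decay rate and frequency are arbitrary assumptions.

```python
import math


def windowed_rms(values, window):
    """Return the RMS of each consecutive, non-overlapping window of samples."""
    values = [float(v) for v in values]
    result = []
    for start in range(0, len(values) - window + 1, window):
        chunk = values[start:start + window]
        result.append(math.sqrt(sum(v * v for v in chunk) / window))
    return result


# Synthetic stand-in for an error trace: a 5 Hz oscillation with a decaying envelope.
dt = 0.001
fake_error = [
    math.exp(-3 * i * dt) * math.sin(2 * math.pi * 5 * i * dt)
    for i in range(3000)
]

# One RMS value per simulated second (1000 samples at dt = 0.001).
rms_per_second = windowed_rms(fake_error, window=1000)
for second, value in enumerate(rms_per_second):
    print(f"second {second}: RMS error = {value:.4f}")
```

After a simulation run, a call along the lines of `windowed_rms(sim.data[error_p][:, 0], window=1000)` should give one value per simulated second, assuming the default 1 ms time step. A learning connection that is working would show these values shrinking over the run.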
Running the network with Nengo Loihi
[5]:
with nengo_loihi.Simulator(model) as sim:
    sim.run(10)
t = sim.trange()
[6]:
plot_decoded(t, sim.data)