In this example, we will use the PES learning rule to learn a communication channel.
[1]:
import matplotlib.pyplot as plt
%matplotlib inline
import nengo
from nengo.processes import WhiteSignal
import nengo_loihi
nengo_loihi.set_defaults()
When creating a nengo.Connection, you can specify a learning_rule_type. When using the nengo.PES learning rule type, the connection is modified so that it can accept input in its learning_rule attribute. That input is interpreted as an error signal, which the PES rule attempts to minimize over time by adjusting decoders or connection weights.
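To make the update concrete, here is a minimal NumPy sketch of what a PES step does to a set of decoders. It is not Nengo's implementation. It assumes a toy population of rectified-linear neurons with activities in [0, 1] (Nengo uses spiking neurons by default), and all helper names and tuning parameters (rates, pes_update, rms_error, gains, biases) are hypothetical. It uses this notebook's sign convention, error = actual - target. As we understand Nengo's reference builder, the step is scaled by learning_rate * dt divided by the number of presynaptic neurons; here kappa stands in for learning_rate * dt.

```python
import numpy as np


def pes_update(decoders, activities, error, kappa):
    """One PES step on `decoders` (shape: dims x n_neurons).

    `error` follows the notebook's convention (actual - target), and
    `kappa` plays the role of learning_rate * dt (assumption).
    """
    n_neurons = activities.shape[0]
    # Move the decoded output against the error, weighted by presynaptic activity
    return decoders - (kappa / n_neurons) * np.outer(error, activities)


def rates(x, encoders, gains, biases):
    # Hypothetical rectified-linear tuning curves; activities stay in [0, 1]
    return np.maximum(0.0, gains * encoders * x + biases)


rng = np.random.RandomState(0)
n = 50
encoders = rng.choice([-1.0, 1.0], size=n)
gains = rng.uniform(0.2, 0.5, size=n)
biases = rng.uniform(-0.3, 0.5, size=n)

# Zero decoders, analogous to function=lambda x: [0]: the output starts at 0
decoders = np.zeros((1, n))
x_test = np.linspace(-1, 1, 101)


def rms_error(decoders):
    acts = np.array([rates(x, encoders, gains, biases) for x in x_test])
    return np.sqrt(np.mean((acts.dot(decoders.T)[:, 0] - x_test) ** 2))


initial = rms_error(decoders)
for _ in range(5000):
    x = rng.uniform(-1, 1)                  # target: a communication channel, y = x
    a = rates(x, encoders, gains, biases)
    error = decoders.dot(a) - x             # actual - target, shape (1,)
    decoders = pes_update(decoders, a, error, kappa=0.5)
final = rms_error(decoders)
print(initial, final)
```

Because every activity is at most 1, a single step changes the decoded output for the current sample by at most kappa times the error, so a kappa below 2 cannot overshoot on that sample. The real rule additionally filters presynaptic activity and runs in continuous time.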
[2]:
with nengo.Network(label="PES learning") as model:
    # Randomly varying input signal
    stim = nengo.Node(WhiteSignal(60, high=5), size_out=1)
    # Connect pre to the input signal
    pre = nengo.Ensemble(100, dimensions=1)
    nengo.Connection(stim, pre)
    post = nengo.Ensemble(100, dimensions=1)
    # When connecting pre to post,
    # create the connection such that initially it will
    # always output 0. Usually this results in connection
    # weights that are also all 0.
    conn = nengo.Connection(pre, post,
                            function=lambda x: [0],
                            learning_rule_type=nengo.PES(learning_rate=2e-4))
    # Calculate the error signal with another ensemble
    error = nengo.Ensemble(100, dimensions=1)
    # Error = actual - target = post - pre
    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)
    # Connect the error into the learning rule
    nengo.Connection(error, conn.learning_rule)
    stim_p = nengo.Probe(stim)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)
    error_p = nengo.Probe(error, synapse=0.01)
We can first run the model in the reference Nengo simulator (nengo.Simulator) to see the desired model output.
[3]:
with nengo.Simulator(model) as sim:
    sim.run(10)
    t = sim.trange()
[4]:
def plot_decoded(t, data):
    plt.figure(figsize=(12, 12))
    plt.subplot(2, 1, 1)
    plt.plot(t, data[stim_p].T[0], label='Input')
    plt.plot(t, data[pre_p].T[0], label='pre')
    plt.plot(t, data[post_p].T[0], label='post')
    plt.ylabel("Decoded output")
    plt.legend(loc='best')
    plt.subplot(2, 1, 2)
    plt.plot(t, data[error_p])
    plt.ylim(-1, 1)
    plt.legend(("Error signal",), loc='best')
plot_decoded(t, sim.data)
While post initially represents only 0, over time it comes to track the value represented in pre more closely. The error signal also decreases gradually as the decoded values in pre and post converge.
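One way to put a number on that decrease is to compare the mean absolute error early and late in a run. The helper below, windowed_abs_error, is hypothetical (it is not part of Nengo or NengoLoihi); it assumes t comes from sim.trange() and err from sim.data[error_p], which has a single column for this one-dimensional error ensemble.

```python
import numpy as np


def windowed_abs_error(t, err, window=1.0):
    """Return mean |error| over the first and last `window` seconds.

    Hypothetical helper: `t` is a 1-D array of times (e.g. from
    sim.trange()) and `err` holds the matching error samples, either 1-D
    or shaped (n_steps, 1) as a probe on a 1-D ensemble returns.
    """
    t = np.asarray(t)
    err = np.asarray(err).ravel()  # flatten (n_steps, 1) to (n_steps,)
    early = np.abs(err[t < t[0] + window]).mean()
    late = np.abs(err[t > t[-1] - window]).mean()
    return early, late
```

After either simulator cell, early, late = windowed_abs_error(t, sim.data[error_p]) would report both values; if learning is working, late should be smaller than early.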
[5]:
with nengo_loihi.Simulator(model) as sim:
    sim.run(10)
    t = sim.trange()
[6]:
plot_decoded(t, sim.data)