Inserting a TensorFlow network into a Nengo model
TensorFlow comes with a wide range of pre-defined deep learning models, which we might want to incorporate into a Nengo model. For example, suppose we are building a biological reinforcement learning model, but we’d like the inputs to our model to be natural images rather than artificial vectors. We could load a vision network from TensorFlow, insert it into our model using NengoDL, and then build the rest of our model using normal Nengo syntax.
In this example we’ll show how to use TensorNodes to insert a pre-trained TensorFlow model (Inception-v1) into Nengo.
In [1]:
%matplotlib inline
import sys
import os
from urllib.request import urlopen
import io
import shutil
import stat
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
import tensorflow.contrib.slim as slim
import nengo
import nengo_dl
TensorFlow provides a number of pre-defined models in the tensorflow/models repository. These are not included when you install TensorFlow, so we need to separately clone that repository and import the components we need.
In [2]:
!git clone https://github.com/tensorflow/models
sys.path.append(os.path.join(".", "models", "research", "slim"))
from datasets import dataset_utils, imagenet
from nets import inception
from preprocessing import inception_preprocessing
Cloning into 'models'...
Checking out files: 100% (2329/2329), done.
We will use a TensorNode to insert our TensorFlow code into Nengo. nengo_dl.TensorNode works very similarly to nengo.Node, except instead of using the node to insert Python code into our model we will use it to insert TensorFlow code.
The first thing we need to do is define our TensorNode output. This should be a function that accepts the current simulation time (and, optionally, a batch of vectors) as input, and produces a batch of vectors as output. All of these variables will be represented as tf.Tensor objects, and the internal operations of the TensorNode will be implemented with TensorFlow operations. For example, we could use a TensorNode to output a sin function:
In [3]:
with nengo.Network() as net:
    node = nengo_dl.TensorNode(lambda t: tf.reshape(tf.sin(t), (1, 1)))
    p = nengo.Probe(node)

with nengo_dl.Simulator(net) as sim:
    sim.run(5.0)

plt.figure()
plt.plot(sim.trange(), sim.data[p])
Building network
Build finished in 0:00:00
Optimization finished in 0:00:00
Construction finished in 0:00:00
Simulation finished in 0:00:01
Out[3]:
[<matplotlib.lines.Line2D at 0x2bdde833630>]
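Note that because the node above has no input (size_in defaults to 0), its function only takes the simulation time t as an argument. To illustrate the optional batch-of-vectors input mentioned earlier, here is a minimal sketch (not part of the original example; the names are illustrative) of a TensorNode that receives input from another node. When size_in is greater than zero, the function is called with a second argument x, a Tensor of shape (batch_size, size_in):
with nengo.Network() as net:
    inp = nengo.Node([0.5, -0.5])
    # with size_in > 0 the TensorNode function is called as f(t, x)
    double = nengo_dl.TensorNode(lambda t, x: x * 2.0, size_in=2, size_out=2)
    nengo.Connection(inp, double, synapse=None)
    p = nengo.Probe(double)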
However, outputting a sin function is something we could do more easily with a regular nengo.Node. The main use case for nengo_dl.TensorNode is to work with artificial neural networks that are not easily defined in Nengo.
In this case we’re going to build a TensorNode that encapsulates the Inception-v1 network. Inception-v1 isn’t state-of-the-art anymore (we’re up to Inception-v4 now), but it is relatively small so it will be quick to download/run in this example. However, this same approach could be used for any TensorFlow network.
Inception-v1 performs image classification; if we show it an image, it will output a set of probabilities for the 1000 different object types it is trained to classify. So if we show it an image of a tree it should output a high probability for the “tree” class and a low probability for the “car” class.
The first thing we’ll do is download a sample image to test our network with (you could use a different image if you want).
In [4]:
url = 'https://upload.wikimedia.org/wikipedia/commons/7/70/EnglishCockerSpaniel_simon.jpg'
image_string = urlopen(url).read()
image = np.array(Image.open(io.BytesIO(image_string)))
image_shape = image.shape
# display the test image
plt.figure()
plt.imshow(image)
plt.axis('off');
Now we’re ready to create our TensorNode. Instead of using a function for our TensorNode output, in this case we’ll use a callable class so that we can include pre_build/post_build functions. These allow us to execute code at different stages during the build process, which may be necessary for more complicated TensorNodes.
The __call__ function is where we construct the TensorFlow elements that will implement our node. It will take TensorFlow Tensors as input and produce a tf.Tensor as output, as with the tf.sin example above.
NengoDL will call the pre_build function once when the model is first constructed, so we can use this function to perform any initial setup required for our node. In this case we’ll use the pre_build function to download pre-trained weights for the Inception network. If we wanted we could instead train the network from scratch using the sim.train function, but that would take a long time and require some expertise in training deep networks; a rough sketch of what such a call would look like is shown below.
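This is a hypothetical sketch only, not executed in this example: train_inputs/train_targets are placeholder dictionaries mapping input Nodes and output Probes to numpy arrays of training data, and the optimizer and hyperparameters are arbitrary choices.
# train_inputs = {input_node: array of shape (n_examples, n_steps, size_in)}
# train_targets = {output_probe: array of shape (n_examples, n_steps, size_out)}
with nengo_dl.Simulator(net, minibatch_size=32) as sim:
    sim.train(train_inputs, train_targets,
              tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9),
              n_epochs=10)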
The post_build function is called after the rest of the graph has been constructed (and whenever the simulation is reset). We’ll use this to load the pretrained weights into the model. We have to do this at the post_build stage because we need access to the initialized simulation session, which has the variables we want to load.
In [5]:
checkpoints_dir = '/tmp/checkpoints'


class InceptionNode(object):
    def pre_build(self, *args):
        # the shape of the inputs to the inception network
        self.input_shape = inception.inception_v1.default_image_size

        # download model checkpoint file
        if not tf.gfile.Exists(checkpoints_dir):
            tf.gfile.MakeDirs(checkpoints_dir)
        url = "http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz"
        dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir)

    def post_build(self, sess, rng):
        # load checkpoint file into model
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
            slim.get_model_variables('InceptionV1'))
        init_fn(sess)

    def __call__(self, t, x):
        # this is the function that will be executed each timestep
        # while the network is running

        # convert our input vector to the shape/dtype of the input image
        image = tf.reshape(tf.cast(x, tf.uint8), image_shape)

        # reshape the image to the shape expected by the inception network
        processed_image = inception_preprocessing.preprocess_image(
            image, self.input_shape, self.input_shape, is_training=False)
        processed_images = tf.expand_dims(processed_image, 0)

        # create inception network
        with slim.arg_scope(inception.inception_v1_arg_scope()):
            logits, _ = inception.inception_v1(processed_images,
                                               num_classes=1001,
                                               is_training=False)
        probabilities = tf.nn.softmax(logits)

        # return our classification probabilities
        return probabilities
Next we create a Nengo Network, containing our TensorNode.
In [6]:
with nengo.Network() as net:
    # create a normal input node to feed in our test image
    input_node = nengo.Node(output=image.flatten())

    # create our TensorNode containing the InceptionNode() we defined
    # above. we also need to specify size_in (the dimensionality of
    # our input vectors, the flattened images) and size_out (the number
    # of classification classes output by the inception network)
    incep_node = nengo_dl.TensorNode(
        InceptionNode(), size_in=np.prod(image_shape), size_out=1001)

    # connect up our input to our inception node
    nengo.Connection(input_node, incep_node, synapse=None)

    # add some probes to collect data
    input_p = nengo.Probe(input_node)
    incep_p = nengo.Probe(incep_node)
Note that at this point we could connect up the output of incep_node to any other part of our network, if this were part of a larger model. But to keep this example simple we’ll stop here.
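For example, the classification probabilities could drive a downstream ensemble of neurons (a hypothetical sketch; the ensemble parameters are arbitrary choices):
with net:
    # an ensemble representing the 1001-dimensional class probabilities
    class_ens = nengo.Ensemble(n_neurons=5000, dimensions=1001)
    nengo.Connection(incep_node, class_ens, synapse=0.01)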
All that’s left is to run our network, using our example image as input, and check the output.
In [7]:
# run the network for one timestep
with nengo_dl.Simulator(net) as sim:
    sim.step()

# sort the output labels by classification probability (descending)
probabilities = sim.data[incep_p][0]
sorted_inds = np.argsort(-probabilities)

# print top 5 classes
names = imagenet.create_readable_names_for_imagenet_labels()
for i in range(5):
    index = sorted_inds[i]
    print('Probability %0.2f%% => [%s]' % (
        probabilities[index] * 100, names[index]))

# display the test image
plt.figure()
plt.imshow(sim.data[input_p][0].reshape(image_shape).astype(np.uint8))
plt.axis('off');
Building network
Build finished in 0:00:00
Optimization finished in 0:00:00
>> Downloading inception_v1_2016_08_28.tar.gz 100.0%
Successfully downloaded inception_v1_2016_08_28.tar.gz 24642554 bytes.
Construction finished in 0:00:06
Probability 44.95% => [cocker spaniel, English cocker spaniel, cocker]
Probability 22.56% => [Sussex spaniel]
Probability 10.18% => [Irish setter, red setter]
Probability 4.48% => [Welsh springer spaniel]
Probability 3.42% => [clumber, clumber spaniel]
In [8]:
# delete the models repo we cloned
def onerror(func, path, exc_info):
    if not os.access(path, os.W_OK):
        os.chmod(path, stat.S_IWUSR)
        func(path)
    else:
        raise

shutil.rmtree("models", onerror=onerror)