"""
The builder manages the mapping between (groups of) Nengo operators and the
builder objects that know how to translate those operators into a
TensorFlow graph.
"""
from collections import namedtuple
import logging
import warnings
from nengo import builder
from nengo.builder import signal
from nengo.exceptions import BuildError
import numpy as np
import tensorflow as tf
from nengo_dl import utils
# Module-level logger; the build phases below emit their debug output here.
logger = logging.getLogger(__name__)
class Builder:
    """
    Registry and driver for the operator build classes known to the
    ``nengo_dl`` build process.

    Every group of operators in ``plan`` is dispatched, based on the type of
    the group's first operator, to the build class registered for that type
    (see `.register`).

    Parameters
    ----------
    plan : list of tuple of `~nengo.builder.Operator`
        The groups of operators that will be built

    Raises
    ------
    BuildError
        If no build class is registered for the type of some group's first
        operator.
    """

    # Class-level registry mapping operator type -> build class. It is shared
    # by every Builder instance and filled in by the `register` decorator.
    builders = {}

    def __init__(self, plan):
        self.plan = plan
        self.op_builds = {}
        for op_group in self.plan:
            op_type = type(op_group[0])
            if op_type not in Builder.builders:
                raise BuildError(
                    "No registered builder for operators of type %r" % op_type
                )
            # one build object per operator group, keyed by the group itself
            self.op_builds[op_group] = Builder.builders[op_type](op_group)

    def _each_group(self, progress):
        """Yield the plan's operator groups, stepping ``progress`` after each."""
        for op_group in self.plan:
            yield op_group
            # resumed only once the caller's loop body for this group finished
            if progress is not None:
                progress.step()

    def build_pre(self, signals, config, progress=None):
        """
        Run the setup step of every build class, in which each one computes
        the values that stay constant across simulation timesteps.

        Parameters
        ----------
        signals : `.signals.SignalDict`
            Mapping from `~nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        config : `.BuildConfig`
            Configuration parameters for the build process
        progress : `.utils.ProgressBar`
            Progress bar for ops in plan
        """
        for op_group in self._each_group(progress):
            logger.debug("===================")
            logger.debug("PRE BUILD %s", op_group)
            # report which signals the operators in this group touch
            for attr in ("sets", "incs", "reads", "updates"):
                logger.debug(attr + " %s", [getattr(op, attr) for op in op_group])
            with self.name_scope(op_group):
                self.op_builds[op_group].build_pre(signals, config)

    def build_step(self, signals, progress=None):
        """
        Construct the computations that make up one simulator timestep.

        Parameters
        ----------
        signals : `.signals.SignalDict`
            Mapping from `~nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        progress : `.utils.ProgressBar`
            Progress bar for ops in plan

        Returns
        -------
        side_effects : list of ``tf.Tensor``
            Outputs with possible side-effects, i.e. computations that need to
            be executed in the TensorFlow graph even if their output doesn't
            appear to be used.
        """
        collected = []
        for op_group in self._each_group(progress):
            logger.debug("===================")
            logger.debug("BUILD %s", op_group)
            with self.name_scope(op_group):
                result = self.op_builds[op_group].build_step(signals)
            # a build class may return a single tensor/variable, a sequence of
            # them, or anything else (e.g. None), which is ignored
            if isinstance(result, (tf.Tensor, tf.Variable)):
                collected.append(result)
            elif isinstance(result, (list, tuple)):
                collected.extend(result)
        return collected

    def build_post(self, signals, progress=None):
        """
        Invoke the post-build hook of every build class in the plan.

        Parameters
        ----------
        signals : `.signals.SignalDict`
            Mapping from `~nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        progress : `.utils.ProgressBar`
            Progress bar for ops in plan
        """
        for op_group in self._each_group(progress):
            logger.debug("===================")
            logger.debug("POST BUILD %s", op_group)
            with self.name_scope(op_group):
                self.op_builds[op_group].build_post(signals)

    def name_scope(self, ops):
        """Returns a new TensorFlow name scope for the given ops."""
        # the scope is named after the build class registered for the group
        build_class = Builder.builders[type(ops[0])]
        scope_name = utils.sanitize_name(build_class.__name__)
        return tf.name_scope(scope_name)

    @classmethod
    def register(cls, nengo_op):
        """
        Decorator factory that records a build class in the registry.

        Parameters
        ----------
        nengo_op : `~nengo.builder.Operator`
            The operator associated with the build function being decorated.

        Returns
        -------
        callable
            Decorator that stores the decorated class under ``nengo_op`` (warning
            if it is not an `OpBuilder` subclass or replaces an existing entry)
            and returns the class unchanged.
        """

        def register_builder(build_class):
            registry = cls.builders
            if not issubclass(build_class, OpBuilder):
                warnings.warn("Build classes should inherit from OpBuilder")
            if nengo_op in registry:
                warnings.warn(
                    "Operator '%s' already has a builder. Overwriting." % nengo_op
                )
            registry[nengo_op] = build_class
            return build_class

        return register_builder
class BuildConfig(
    namedtuple(
        "BuildConfig", "inference_only lif_smoothing cpu_only rng training"
    )
):
    """
    Immutable bundle of configuration values that individual parts of the
    build process may need to consult.

    Parameters
    ----------
    inference_only : bool
        If True the network should be constructed in "inference only" mode
        (omitting any support for training operations).
    lif_smoothing : float
        Smoothing parameter for `~nengo.LIF` gradient approximation.
    cpu_only : bool
        True if TensorFlow is only running on the CPU (because that was
        specified by the user or because GPU support is not available).
    rng : `~numpy.random.RandomState`
        Seeded random number generator.
    training : ``tf.Tensor`` (bool)
        True if building in training mode, False for inference mode.
    """

    # keep instances as lightweight as the underlying tuple (no __dict__)
    __slots__ = ()
class OpBuilder:
    """
    Base class for build classes, which implement the logic for building a group of
    Nengo Operators into TensorFlow.

    Subclasses must override `build_step`; `build_pre`, `build_post` and
    `mergeable` provide defaults.
    """

    def __init__(self, ops):
        """
        Store the operator group handled by this build class.

        Note: this should not be building any model operations, this is purely for
        internal setup of the ``OpBuilder`` itself.

        Parameters
        ----------
        ops : list of `~nengo.builder.Operator`
            The operator group to build into the model
        """
        self.ops = ops

    def build_pre(self, signals, config):
        """
        Set up the computations that are fixed for this op, i.e. anything
        that does not have to be recomputed on every timestep.

        The default implementation logs the build class and its operators and
        stores ``config`` as ``self.config``.

        Parameters
        ----------
        signals : `.signals.SignalDict`
            Mapping from `~nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        config : `~.builder.BuildConfig`
            General repository for config information builders might want
            (conglomerated into this object so that we can add/remove config data
            without having to change the function signature all the time).
        """
        logger.debug(type(self).__name__)
        logger.debug("\n".join(map(str, self.ops)))
        self.config = config

    def build_step(self, signals):
        """
        Build whatever computations must run on every simulation timestep.

        Parameters
        ----------
        signals : `.signals.SignalDict`
            Mapping from `~nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)

        Returns
        -------
        side_effects : list of ``tf.Tensor``
            If not None, the returned tensors correspond to outputs with
            possible side-effects, i.e. computations that need to be executed
            in the TensorFlow graph even if their output doesn't appear to be
            used

        Raises
        ------
        BuildError
            Always, in this base class; subclasses must override it.
        """
        raise BuildError("OpBuilders must implement a `build_step` function")

    def build_post(self, signals):
        """
        Hook run after the graph has been built and again every time the
        Simulator is reset.

        Because it may run several times per session, implementations should
        perform any required operations in-place. The default does nothing.

        Parameters
        ----------
        signals : `.signals.SignalDict`
            Mapping from `~nengo.builder.Signal` to
            ``tf.Tensor`` (updated by operations)
        """
        pass

    @staticmethod
    def mergeable(x, y):
        """
        Decide whether two operators of this builder's type can be merged.

        Parameters
        ----------
        x : `nengo.builder.Operator`
            The operator being tested
        y : `nengo.builder.Operator`
            The operator being merged into (this is representative of a group
            of operators that have already been merged)

        Returns
        -------
        mergeable : bool
            True if ``x`` and ``y`` can be merged into a single built op,
            else ``False``. The base class never allows merging.
        """
        return False
class NengoBuilder(builder.Builder):
    """
    Copy of the default Nengo builder.

    This class is here so that we can register new build functions for
    Nengo DL without affecting the default Nengo build process.
    """

    # Build functions registered specifically for NengoDL; kept separate from
    # ``builder.Builder.builders`` so the default Nengo registry is untouched.
    builders = {}

    @classmethod
    def build(cls, model, obj, *args, **kwargs):
        """
        Build ``obj`` into ``model``.

        If a NengoDL build function is registered for ``obj``'s type (or any
        class in its MRO), that function is used; otherwise the default Nengo
        build function is used.

        In addition to the parameters listed below, further positional and
        keyword arguments will be passed unchanged into the build function.

        Parameters
        ----------
        model : Model
            The `~nengo.builder.Model` instance in which to store build
            artifacts.
        obj : object
            The object to build into the model.

        Returns
        -------
        object
            Whatever the selected build function returns.

        Raises
        ------
        BuildError
            Propagated from the selected build function, or from the default
            Nengo builder when no build function exists for ``obj`` at all.
        """
        # Pick the registry up front, using the same MRO search the Nengo
        # builder performs, instead of trying the NengoDL registry and catching
        # BuildError. The exception-based fallback also swallowed genuine
        # BuildErrors raised *inside* a NengoDL build function (or a nested
        # build it triggered), silently re-building ``obj`` with the default
        # builder on a model that may already have been partially modified.
        if any(obj_cls in cls.builders for obj_cls in type(obj).__mro__):
            # a custom build function registered by Nengo DL takes priority
            builder_cls = cls
        else:
            # fall back on the normal nengo builder
            builder_cls = builder.Builder
        return builder.Builder.build.__func__(
            builder_cls, model, obj, *args, **kwargs
        )
class NengoModel(builder.Model):
    """
    Copy of the default Nengo model.

    This allows us to override certain model behaviours.

    Parameters
    ----------
    fail_fast : bool
        If True, try to call ``op.make_step`` when ops are added to the model.
        Note that NengoDL doesn't actually use ``make_step``, so errors in that
        function are not necessarily errors in NengoDL (which is why we want to
        disable that check). But it might still be useful when debugging
        new op/build functions, which is why we leave the option.
    """

    def __init__(self, *args, fail_fast=True, **kwargs):
        # Assigned before the parent constructor runs because ``add_op`` reads
        # this attribute. NOTE(review): presumably the parent ``__init__`` may
        # add operators itself; confirm against nengo.builder.Model.
        self.fail_fast = fail_fast
        super().__init__(*args, **kwargs)

    def add_op(self, op):
        """
        Append an operator to the model, optionally trying it out immediately.

        Parameters
        ----------
        op : `~nengo.builder.Operator`
            Operator being added to the model.

        Notes
        -----
        This is a copy of the parent `nengo.builder.Model.add_op`, with the
        addition of the ``if self.fail_fast`` condition.
        """
        # TODO: nengo 3.0 adds something similar to this condition, but
        # it uses an rc setting (so we can't change it in nengo-dl without
        # also changing nengo core). if the rc system is reworked to allow
        # backend-specific overrides, we could remove this class.
        self.operators.append(op)
        if not self.fail_fast:
            return

        # Fail fast: exercise make_step against a throwaway signal dict, using
        # the global ``np.random`` module as the random number source.
        scratch = signal.SignalDict()
        op.init_signals(scratch)
        op.make_step(scratch, self.dt, np.random)