Source code for nengo_dl.transform_builders

"""
Build classes for Nengo transform operators.
"""

import warnings

import numpy as np
import tensorflow as tf

from nengo_dl.builder import Builder, OpBuilder
from nengo_dl.compat import tf_convolution, ConvInc


@Builder.register(ConvInc)
class ConvIncBuilder(OpBuilder):
    """
    Build a group of ``ConvInc`` operators.

    Every operator in a group reads the same input signal (see
    ``mergeable``), so the whole group is computed with a single call to
    ``tf_convolution``. The per-operator kernels are concatenated along the
    output-channel axis, and the convolution result is then split back out
    into the per-operator outputs.
    """

    # TODO: fix link to `~nengo.builder.transform.ConvInc` once it exists

    def __init__(self, ops, signals, config):
        super(ConvIncBuilder, self).__init__(ops, signals, config)

        # the merge conditions guarantee that these convolution parameters
        # are shared by every op in the group
        self.conv = ops[0].conv

        if not self.conv.channels_last and config.cpu_only:
            # TensorFlow doesn't support channels first on CPU, so if
            # tensorflow-gpu isn't installed we need to force channels_last
            # TODO: check if this is supported in future versions
            warnings.warn(
                "TensorFlow does not support convolution with "
                "channels_last=False on the CPU; inputs will be transformed "
                "to channels_last=True",
                UserWarning,
            )
            force_last = True
        else:
            force_last = False

        # create data format string (e.g. "NHWC" or "NCHW")
        fmts = ["W", "HW", "DHW"]
        if self.conv.dimensions > len(fmts):
            raise NotImplementedError(
                "Convolutions > %d dimensions are not supported" % len(fmts)
            )
        fmt = fmts[self.conv.dimensions - 1]
        self.fmt = (
            "N" + fmt + "C" if self.conv.channels_last or force_last else "NC" + fmt
        )

        self.W_data = signals.combine([op.W for op in ops])
        # all the ops have the same input, so we just use one
        self.X_data = signals[ops[0].X]
        self.X_data = self.X_data.reshape(self.conv.input_shape.shape)
        self.Y_data = signals.combine([op.Y for op in ops])

        assert self.X_data.minibatched
        if self.W_data.minibatched:
            raise NotImplementedError(
                "Minibatched convolutional weights are not supported"
            )

        # set up X transformations
        # move batch to front (rolling by one moves the last axis to axis 0)
        perm_x = np.roll(np.arange(self.conv.dimensions + 2), 1)
        if force_last:
            # move channel dimension to the end
            perm_x[1:-1] = perm_x[2:]
            perm_x[-1] = 0
        self.perm_x = signals.constant(perm_x)

        # set up Y transformations (convolution output -> Y_data layout)
        if len(ops) > 1:
            if self.conv.channels_last or force_last:
                # separate last dimension into output for each op; the
                # weight concatenation below interleaves ops along the
                # output-channel axis, so ops are the innermost factor
                reshape_y = (
                    (signals.minibatch_size,)
                    + self.conv.output_shape.spatial_shape
                    + (-1, len(ops))
                )
                # move ops to front and batch to end
                perm_y = np.arange(self.conv.dimensions + 3)
                perm_y[[0, -1]] = perm_y[[-1, 0]]
                if force_last:
                    # move channel dimension back to the front
                    perm_y[1:-1] = perm_y[:-2]
                    perm_y[1] = len(perm_y) - 2
            else:
                reshape_y = (
                    signals.minibatch_size,
                    -1,
                    len(ops),
                ) + self.conv.output_shape.spatial_shape
                perm_y = (2, 1) + tuple(range(3, self.conv.dimensions + 3)) + (0,)
            self.reshape_y = signals.constant(reshape_y)
            self.perm_y = signals.constant(perm_y)
        else:
            self.reshape_y = None

            # move batch to end
            perm_y = np.roll(np.arange(self.conv.dimensions + 2), -1)
            if force_last:
                # move channel dimension back to the front
                perm_y[1:-1] = perm_y[:-2]
                perm_y[0] = len(perm_y) - 1
            self.perm_y = signals.constant(perm_y)

        # set up W transformations
        if len(ops) > 1:
            # separate the combined weights per op (this relies on every op
            # having the same ``kernel_shape``, enforced in ``mergeable``),
            # then move the ops axis to the end
            self.W_data = self.W_data.reshape((len(ops),) + self.conv.kernel_shape)
            self.perm_w = signals.constant(
                np.roll(np.arange(self.conv.dimensions + 3), -1)
            )

            # concatenate weights for each op along output channel dimension
            self.reshape_w = signals.constant(
                self.conv.kernel_size + (self.conv.input_shape.n_channels, -1)
            )
        else:
            self.perm_w = None
            self.reshape_w = None

    def build_step(self, signals):
        """
        Apply the merged convolution and increment the output signals.

        Parameters
        ----------
        signals
            Signal dictionary used to gather the weights/input and to
            scatter (with ``mode="inc"``) the result into ``Y_data``.
        """
        W = signals.gather(self.W_data)
        X = signals.gather(self.X_data)

        # put batch dimension first
        X = tf.transpose(a=X, perm=self.perm_x)

        if self.perm_w is not None:
            # concatenate kernels along output channel dimension
            W = tf.transpose(a=W, perm=self.perm_w)
            W = tf.reshape(W, self.reshape_w)

        Y = tf_convolution(
            input=X,
            filters=W,
            strides=self.conv.strides,
            data_format=self.fmt,
            padding=self.conv.padding.upper(),
        )

        # move batch back to end, ops to front
        if self.reshape_y is not None:
            Y = tf.reshape(Y, self.reshape_y)
        Y = tf.transpose(a=Y, perm=self.perm_y)

        signals.scatter(self.Y_data, Y, mode="inc")

    @staticmethod
    def mergeable(x, y):
        """
        Return whether ``ConvInc`` ops ``x`` and ``y`` can be built together.

        Convolutions are merged only if they share the same input signal, as
        then several kernels can be applied to that input efficiently in one
        convolution. They must also agree on every parameter that the merged
        build assumes is common to the group: input shape, strides, padding,
        channel ordering and kernel shape.
        """
        return (
            x.X is y.X
            and x.conv.input_shape.shape == y.conv.input_shape.shape
            and x.conv.strides == y.conv.strides
            and x.conv.padding == y.conv.padding
            and x.conv.channels_last == y.conv.channels_last
            # the builder reshapes the combined weights with the first op's
            # ``kernel_shape`` and splits the output channels into equal
            # groups per op, so kernel size and filter count must match too
            and x.conv.kernel_shape == y.conv.kernel_shape
        )