"""
Build classes for Nengo transform operators.
"""
import warnings
import numpy as np
import tensorflow as tf
from nengo_dl.builder import Builder, OpBuilder
from nengo_dl.compat import tf_convolution, ConvInc
@Builder.register(ConvInc)
class ConvIncBuilder(OpBuilder):
    """
    Build a group of ``ConvInc`` operators.

    All operators in a group read the same input signal (see ``mergeable``),
    so the whole group is computed with a single convolution call: the
    kernels of the individual operators are concatenated along the output
    channel dimension, and the convolution result is then split back into
    one output per operator and added into the output signals.
    """

    # TODO: fix link to `~nengo.builder.transform.ConvInc` once it exists

    def __init__(self, ops, signals, config):
        super(ConvIncBuilder, self).__init__(ops, signals, config)

        # ops in a group are expected to agree on everything checked in
        # ``mergeable``, so the first op's ``conv`` is used for all of them
        self.conv = ops[0].conv

        if not self.conv.channels_last and config.cpu_only:
            # TensorFlow doesn't support channels first on CPU, so if
            # tensorflow-gpu isn't installed we need to force channels_last
            # TODO: check if this is supported in future versions
            warnings.warn(
                "TensorFlow does not support convolution with "
                "channels_last=False on the CPU; inputs will be transformed "
                "to channels_last=True",
                UserWarning,
            )
            force_last = True
        else:
            force_last = False

        # create data format string, e.g. "NHWC" (channels last) or "NCHW"
        # (channels first) for a 2D convolution
        fmts = ["W", "HW", "DHW"]
        if self.conv.dimensions > len(fmts):
            raise NotImplementedError(
                "Convolutions > %d dimensions are not supported" % len(fmts)
            )
        fmt = fmts[self.conv.dimensions - 1]
        self.fmt = (
            "N" + fmt + "C" if self.conv.channels_last or force_last else "NC" + fmt
        )

        self.W_data = signals.combine([op.W for op in ops])
        # all the ops have the same input, so we just use one
        self.X_data = signals[ops[0].X]
        self.X_data = self.X_data.reshape(self.conv.input_shape.shape)
        self.Y_data = signals.combine([op.Y for op in ops])

        assert self.X_data.minibatched
        if self.W_data.minibatched:
            raise NotImplementedError(
                "Minibatched convolutional weights are not supported"
            )

        # --- X transformations ---
        # The gathered input has the minibatch axis last; rolling by one
        # moves that last axis to the front (2D: perm_x = [3, 0, 1, 2]).
        perm_x = np.roll(np.arange(self.conv.dimensions + 2), 1)
        if force_last:
            # the data is channels-first (channel is original axis 0), so
            # shift the spatial axes left and put the channel axis at the end
            # (2D: [3, 0, 1, 2] -> [3, 1, 2, 0])
            perm_x[1:-1] = perm_x[2:]
            perm_x[-1] = 0
        self.perm_x = signals.constant(perm_x)

        # --- Y transformations ---
        # The convolution output has the batch axis first; it must be moved
        # back to the end, and (for merged ops) the concatenated output
        # channels must be split back into one block per op.
        if len(ops) > 1:
            if self.conv.channels_last or force_last:
                # Split the trailing channel axis into (per-op channel, op).
                # The op index is the fastest-varying part because of how the
                # kernels are concatenated (ops moved to the last kernel axis
                # before merging with the output channel axis, see below).
                reshape_y = (
                    (signals.minibatch_size,)
                    + self.conv.output_shape.spatial_shape
                    + (-1, len(ops))
                )
                # move ops to front and batch to end
                perm_y = np.arange(self.conv.dimensions + 3)
                perm_y[[0, -1]] = perm_y[[-1, 0]]
                if force_last:
                    # move channel dimension back to the front (right after
                    # the op axis); numpy buffers the overlapping slice copy
                    perm_y[1:-1] = perm_y[:-2]
                    perm_y[1] = len(perm_y) - 2
            else:
                # channels-first output: split the channel axis (axis 1)
                # into (per-op channel, op), then put ops first, batch last
                reshape_y = (
                    signals.minibatch_size,
                    -1,
                    len(ops),
                ) + self.conv.output_shape.spatial_shape
                perm_y = (2, 1) + tuple(range(3, self.conv.dimensions + 3)) + (0,)
            self.reshape_y = signals.constant(reshape_y)
            self.perm_y = signals.constant(perm_y)
        else:
            self.reshape_y = None
            # move batch to end (2D: perm_y = [1, 2, 3, 0])
            perm_y = np.roll(np.arange(self.conv.dimensions + 2), -1)
            if force_last:
                # convolution ran channels-last but the signal is
                # channels-first: channel axis first, batch still last
                perm_y[1:-1] = perm_y[:-2]
                perm_y[0] = len(perm_y) - 1
            self.perm_y = signals.constant(perm_y)

        # --- W transformations ---
        if len(ops) > 1:
            # Stack the combined weights as (op, *kernel_shape). This relies
            # on every op having ops[0]'s kernel shape, which ``mergeable``
            # guarantees.
            self.W_data = self.W_data.reshape((len(ops),) + self.conv.kernel_shape)
            # move ops to end
            self.perm_w = signals.constant(
                np.roll(np.arange(self.conv.dimensions + 3), -1)
            )
            # concatenate weights for each op along output channel dimension
            self.reshape_w = signals.constant(
                self.conv.kernel_size + (self.conv.input_shape.n_channels, -1)
            )
        else:
            self.perm_w = None
            self.reshape_w = None

    def build_step(self, signals):
        """
        Compute the grouped convolution and increment the output signals.

        Gathers the weights and the shared input, applies the permutations and
        reshapes prepared in ``__init__``, runs one convolution for the whole
        group, and scatters the result into ``Y_data`` with ``mode="inc"``.
        """
        W = signals.gather(self.W_data)
        X = signals.gather(self.X_data)

        # put batch dimension first
        X = tf.transpose(a=X, perm=self.perm_x)

        if self.perm_w is not None:
            # concatenate kernels along output channel dimension
            W = tf.transpose(a=W, perm=self.perm_w)
            W = tf.reshape(W, self.reshape_w)

        Y = tf_convolution(
            input=X,
            filters=W,
            strides=self.conv.strides,
            data_format=self.fmt,
            padding=self.conv.padding.upper(),
        )

        # move batch back to end, ops to front
        if self.reshape_y is not None:
            Y = tf.reshape(Y, self.reshape_y)
        Y = tf.transpose(a=Y, perm=self.perm_y)

        signals.scatter(self.Y_data, Y, mode="inc")

    @staticmethod
    def mergeable(x, y):
        """
        Return True if ``ConvInc`` ops ``x`` and ``y`` can be built together.

        We allow convolutions to merge if they have the same input signal (as
        then we can efficiently apply several kernels to the same input).
        Input shape, strides, padding and channel layout must match. The
        kernel shape must match as well: the builder stacks all kernels with
        ``ops[0].conv.kernel_shape`` and splits the output with
        ``ops[0].conv``'s output shape, which is only valid if every op uses
        the same kernel size and number of filters.
        """
        return (
            x.X is y.X
            and x.conv.input_shape.shape == y.conv.input_shape.shape
            and x.conv.kernel_shape == y.conv.kernel_shape
            and x.conv.strides == y.conv.strides
            and x.conv.padding == y.conv.padding
            and x.conv.channels_last == y.conv.channels_last
        )