Source code for nengo_extras.networks.matrix_multiplication

import numpy as np

import nengo


[docs]def MatrixMult(n_neurons, shape_left, shape_right, net=None): """Computes the matrix product A*B. Both matrices need to be two dimensional. See the `Matrix Multiplication example <:doc:networks>`_ for a description of the network internals. Parameters ---------- n_neurons : int Number of neurons used per product of two scalars. .. note:: If an odd number of neurons is given, one less neuron will be used per product to obtain an even number. This is due to the implementation the `.Product` network. shape_left : tuple Shape of the A input matrix. shape_right : tuple Shape of the B input matrix. net : Network, optional (Default: None) A network in which the network components will be built. This is typically used to provide a custom set of Nengo object defaults through modifying ``net.config``. Returns ------- net : Network The newly built matrix multiplication network, or the provided ``net``. """ if len(shape_left) != 2: raise ValueError("Shape {} is not two dimensional.".format(shape_left)) if len(shape_right) != 2: raise ValueError( "Shape {} is not two dimensional.".format(shape_right)) if shape_left[1] != shape_right[0]: raise ValueError( "Matrix dimensions {} and {} are incompatible".format( shape_left, shape_right)) if net is None: net = nengo.Network(label="Matrix multiplication") size_left = np.prod(shape_left) size_right = np.prod(shape_right) with net: net.input_left = nengo.Node(size_in=size_left) net.input_right = nengo.Node(size_in=size_right) # The C matrix is composed of populations that each contain # one element of A (left) and one element of B (right). # These elements will be multiplied together in the next step. size_c = size_left * shape_right[1] net.C = nengo.networks.Product(n_neurons, size_c) # Determine the transformation matrices to get the correct pairwise # products computed. This looks a bit like black magic but if # you manually try multiplying two matrices together, you can see # the underlying pattern. Basically, we need to build up D1*D2*D3 # pairs of numbers in C to compute the product of. If i,j,k are the # indexes into the D1*D2*D3 products, we want to compute the product # of element (i,j) in A with the element (j,k) in B. The index in # A of (i,j) is j+i*D2 and the index in B of (j,k) is k+j*D3. # The index in C is j+k*D2+i*D2*D3, multiplied by 2 since there are # two values per ensemble. We add 1 to the B index so it goes into # the second value in the ensemble. transform_left = np.zeros((size_c, size_left)) transform_right = np.zeros((size_c, size_right)) for i, j, k in np.ndindex(shape_left[0], *shape_right): c_index = (j + k * shape_right[0] + i * size_right) transform_left[c_index][j + i * shape_right[0]] = 1 transform_right[c_index][k + j * shape_right[1]] = 1 nengo.Connection( net.input_left, net.C.A, transform=transform_left, synapse=None) nengo.Connection( net.input_right, net.C.B, transform=transform_right, synapse=None) # Now do the appropriate summing size_output = shape_left[0] * shape_right[1] net.output = nengo.Node(size_in=size_output) # The mapping for this transformation is much easier, since we want to # combine D2 pairs of elements (we sum D2 products together) transform_c = np.zeros((size_output, size_c)) for i in range(size_c): transform_c[i // shape_right[0]][i] = 1 nengo.Connection( net.C.output, net.output, transform=transform_c, synapse=None) return net