Source code for nengo_extras.matplotlib

from __future__ import absolute_import

import matplotlib.pyplot as plt
import numpy as np


def imshow(image, ax=None, vmin=None, vmax=None, invert=False,
           interpolation='none', axes=False):
    """Nicer version of Matplotlib's imshow.

    - By default, show the raw image with no interpolation.
    - If the image is greyscale, use grey colormap.

    Parameters
    ----------
    image : ndarray (height, width) or (height, width, channels)
        Image to show. A 2-D array, or a 3-D array with a single channel,
        is drawn as greyscale. Any other 3-D array must have exactly
        three channels.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    vmin, vmax : scalar, optional
        Forwarded to ``ax.imshow``. For three-channel images, when both
        are given, the image is first clipped to ``[vmin, vmax]`` and
        rescaled to ``[0, 1]``.
    invert : bool, optional
        For greyscale images, use the 'gist_yarg' colormap instead of
        'gray'.
    interpolation : str, optional
        Forwarded to ``ax.imshow``.
    axes : bool, optional
        If False (default), hide the x and y axes.

    Returns
    -------
    ax_img
        The object returned by ``ax.imshow``.

    Raises
    ------
    ValueError
        If ``image`` is not 2-D or 3-D, or if a 3-D image has a channel
        count other than 1 or 3.
    """
    kwargs = dict(vmin=vmin, vmax=vmax, interpolation=interpolation)

    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        # Greyscale: drop the singleton channel axis if there is one.
        if image.ndim == 3:
            image = image[:, :, 0]
        kwargs['cmap'] = 'gist_yarg' if invert else 'gray'
    elif image.ndim == 3:
        # Raise rather than assert so the check survives ``python -O``.
        if image.shape[2] != 3:
            raise ValueError(
                "Color images must have 3 channels (got %d)"
                % image.shape[2])
        if vmin is not None and vmax is not None:
            # Normalize the color image by hand using the requested range.
            image = (image.clip(vmin, vmax) - vmin) / (vmax - vmin)
    else:
        raise ValueError("Wrong number of image dimensions")

    ax = plt.gca() if ax is None else ax
    ax_img = ax.imshow(image, **kwargs)
    if not axes:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    return ax_img
def tile(images, ax=None, rows=9, cols=12, grid=False,
         gridwidth=1, gridcolor='r', **show_params):
    """Plot a grid of images.

    Parameters
    ----------
    images : ndarray (n_images, height, width, channels)
        Array of images to display. A 3-D array is treated as
        single-channel images of shape (n_images, height, width).
    ax : matplotlib Axes, optional
        Axes to draw on; forwarded to `imshow`.
    rows, cols : int, optional
        Maximum grid size. With fewer than ``rows * cols`` images, the
        grid is shrunk while roughly keeping the ``cols / rows`` aspect
        ratio. With more, only the first ``rows * cols`` images are shown.
    grid : bool, optional
        If True, draw lines between the tiles.
    gridwidth : scalar, optional
        Line width of the grid lines.
    gridcolor : color, optional
        Color of the grid lines.
    **show_params
        Extra keyword arguments forwarded to `imshow`.

    Raises
    ------
    ValueError
        If ``images`` is neither 3-D nor 4-D.
    """
    if images.ndim == 3:
        images = images[:, :, :, None]
    if images.ndim != 4:
        raise ValueError(
            "`images` must have 3 or 4 dimensions (got %d)" % images.ndim)
    n_images, m, n, nc = images.shape

    if n_images < rows * cols:
        # Too few images to fill the grid: pick a smaller one with a
        # similar aspect ratio.
        aspect = float(cols) / rows
        rows = int(np.maximum(np.round(np.sqrt(n_images / aspect)), 1))
        cols = int(np.ceil(float(n_images) / rows))

    # Paste the images row by row into one mosaic; unused cells stay zero.
    img = np.zeros((m * rows, n * cols, nc), dtype=images.dtype)
    for k in range(min(rows * cols, n_images)):
        i, j = divmod(k, cols)
        img[i*m:(i+1)*m, j*n:(j+1)*n, :] = images[k]

    ax_img = imshow(img, ax=ax, **show_params)
    ax = ax_img.axes

    if grid:
        x_extent = [-0.5, img.shape[1] - 0.5]
        y_extent = [-0.5, img.shape[0] - 0.5]
        for i in range(1, rows):
            ax.plot(x_extent, [i*m - 0.5, i*m - 0.5],
                    color=gridcolor, linewidth=gridwidth)
        for j in range(1, cols):
            ax.plot([j*n - 0.5, j*n - 0.5], y_extent,
                    color=gridcolor, linewidth=gridwidth)

        # Reset the view to the image extent after plotting the lines.
        ax.set_xlim(x_extent)
        ax.set_ylim(y_extent)
        ax.invert_yaxis()
def compare(image_sets, ax=None, rows=4, cols=12, grid=True,
            gridwidth=1, gridcolor='r', **show_params):
    """Compare sets of images in a grid.

    Within each grid cell, the corresponding image from every set is
    stacked vertically, in the order of ``image_sets``.

    Parameters
    ----------
    image_sets : list of (n_images, height, width, channels) ndarray
        List of the sets of images to compare. Each set of images must be
        the same size. The channels axis may be omitted, in which case
        one channel is assumed.
    ax : matplotlib Axes, optional
        Axes to draw on; forwarded to `imshow`.
    rows, cols : int, optional
        Maximum grid size. With fewer than ``rows * cols`` images per
        set, the grid is shrunk while roughly keeping the ``cols / rows``
        aspect ratio. With more, only the first ``rows * cols`` images of
        each set are shown.
    grid : bool, optional
        If True (default), draw lines between the grid cells.
    gridwidth : scalar, optional
        Line width of the grid lines.
    gridcolor : color, optional
        Color of the grid lines.
    **show_params
        Extra keyword arguments forwarded to `imshow`.

    Raises
    ------
    ValueError
        If any image set differs in shape from the first one.
    """
    first = image_sets[0]
    # Check consistency up front, before any shape-derived values are used.
    for image_set in image_sets:
        if image_set.shape != first.shape:
            raise ValueError(
                "All image sets must be the same shape as the first")

    d = len(image_sets)
    n_images = first.shape[0]
    imshape = first.shape[1:]
    m, n = imshape[:2]
    nc = imshape[2] if len(imshape) > 2 else 1

    if n_images < rows * cols:
        # Too few images to fill the grid: pick a smaller one with a
        # similar aspect ratio.
        aspect = float(cols) / rows
        rows = int(np.maximum(np.round(np.sqrt(n_images / aspect)), 1))
        cols = int(np.ceil(float(n_images) / rows))

    # Each grid cell is d images tall: one image per set.
    img = np.zeros((d * m * rows, n * cols, nc), dtype=first.dtype)
    for k in range(min(rows * cols, n_images)):
        i, j = divmod(k, cols)
        for q, image_set in enumerate(image_sets):
            top = (d * i + q) * m
            img[top:top + m, j*n:(j+1)*n] = image_set[k].reshape(m, n, nc)

    ax_img = imshow(img, ax=ax, **show_params)
    ax = ax_img.axes

    if grid:
        x_extent = [-0.5, img.shape[1] - 0.5]
        y_extent = [-0.5, img.shape[0] - 0.5]
        for i in range(1, rows):
            # One horizontal line per cell boundary (every d*m pixels).
            ax.plot(x_extent, (d * i * m - 0.5) * np.ones(2),
                    color=gridcolor, linewidth=gridwidth)
        for j in range(1, cols):
            ax.plot([j*n - 0.5, j*n - 0.5], y_extent,
                    color=gridcolor, linewidth=gridwidth)

        # Reset the view to the image extent after plotting the lines.
        ax.set_xlim(x_extent)
        ax.set_ylim(y_extent)
        ax.invert_yaxis()