"""
Tools for visualizing deep neural networks
"""
import collections
import numpy as np
import PIL.ImageTk
from .compat import tkinter as tk
class ImageSelector(tk.Frame):
    """Choose between images with left and right arrows.

    Attributes
    ----------
    image_function : callable
        Turn an array into a PIL.Image. Can be generated by
        ``nengo_extras.gui.image_function``.
    resample : int
        Resampling mode for ``PIL.Image.resize``.
    """

    def __init__(self, parent, **kwargs):
        tk.Frame.__init__(self, parent, **kwargs)

        self.image_function = None
        self.resample = 0

        # Current unscaled PIL image and the PhotoImage displayed on the
        # canvas. Both start empty because the canvas can receive a
        # <Configure> event (and thus a redraw) before set_image is called.
        self.image = None
        self.photo = None

        self.canvas = tk.Canvas(self)
        self.canvas_image = self.canvas.create_image(0, 0, anchor='nw')
        self.left = tk.Button(self, text='<')
        self.right = tk.Button(self, text='>')

        self.canvas.bind('<Configure>', self._on_canvas_resize)

        # --- layout
        self.left.pack(side='left', expand=False)
        self.right.pack(side='right', expand=False)
        self.canvas.pack(fill='both', expand=True)

    def set_left_callback(self, callback):
        """Set the command invoked when the left ('<') button is pressed."""
        self.left.configure(command=callback)

    def set_right_callback(self, callback):
        """Set the command invoked when the right ('>') button is pressed."""
        self.right.configure(command=callback)

    def _canvas_resized(self):
        """Rescale the current image to the canvas size and redraw it."""
        if self.image is None:
            return  # no image has been set yet; nothing to draw

        h = self.canvas.winfo_height()
        w = self.canvas.winfo_width()
        if w < 1 or h < 1:
            return  # canvas has no drawable area; cannot resize to 0 px

        image = self.image.resize((w, h), self.resample)
        # Keep a reference on self: the PhotoImage must stay alive for as
        # long as the canvas displays it.
        self.photo = PIL.ImageTk.PhotoImage(image)
        self.canvas.itemconfig(self.canvas_image, image=self.photo)

    def _on_canvas_resize(self, event):
        """Tk <Configure> handler: redraw at the new canvas size."""
        self._canvas_resized()

    def set_image(self, image_array):
        """Convert ``image_array`` with ``image_function`` and display it."""
        self.image = self.image_function(image_array)
        self._canvas_resized()
# NOTE(review): ScrollCanvasFrame is not imported in the visible part of this
# file; confirm it is brought into scope (e.g. from the package's gui module).
class VerticalImageFrame(ScrollCanvasFrame):
    """A vertically scrolling grid of images laid out in fixed-size columns.

    Columns are added with ``add_column``; each has its own display size and
    its own function converting an array into a PIL.Image. ``set_images``
    then fills the grid, one row per image index.
    """

    Column = collections.namedtuple(
        'Column', ('width', 'height', 'image_function'))

    def __init__(self, parent, **kwargs):
        ScrollCanvasFrame.__init__(
            self, parent, vertical=True, **kwargs)

        self.padding = (5, 5)  # width, height

        self.columns = []
        self.canvas_images = []  # by row, then column
        self.images = []  # by row, then column
        self.photos = []  # by row, then column

    @property
    def n_columns(self):
        """Number of columns added with ``add_column``."""
        return len(self.columns)

    @property
    def n_rows(self):
        """Number of rows of canvas image items currently created."""
        return len(self.canvas_images)

    @property
    def column_heights(self):
        """Generator over the display height of each column."""
        return (column.height for column in self.columns)

    @property
    def column_widths(self):
        """Generator over the display width of each column."""
        return (column.width for column in self.columns)

    @property
    def column_image_functions(self):
        """Generator over the array-to-PIL.Image function of each column."""
        return (column.image_function for column in self.columns)

    def add_column(self, width, height, image_function):
        """Append a column whose images are displayed at (width, height)."""
        self.columns.append(self.Column(width, height, image_function))
        # create_canvas_images places columns padding[0] pixels apart, so the
        # requested canvas width must include those gaps; otherwise the last
        # column may be clipped.
        n_gaps = self.n_columns - 1
        self.canvas.configure(
            width=sum(self.column_widths) + n_gaps * self.padding[0])

    def create_canvas_images(self, n):
        """Replace all canvas image items with ``n`` fresh rows of items.

        Columns are spaced ``padding[0]`` apart horizontally; rows are
        spaced by the tallest column height plus ``padding[1]``.
        """
        for image_row in self.canvas_images:
            for image in image_row:
                self.canvas.delete(image)
        del self.canvas_images[:]

        max_height = max(self.column_heights)
        widths = np.array(list(self.column_widths))
        # Left edge of each column: running sum of (width + padding),
        # starting at 0 and dropping the final total.
        pos_x = np.cumsum([0] + list(widths + self.padding[0]))[:-1]
        for i in range(n):
            image_row = [
                self.canvas.create_image(
                    pos_x[j], i*(max_height + self.padding[1]), anchor='nw')
                for j in range(self.n_columns)]
            self.canvas_images.append(image_row)

    def set_images(self, column_images):
        """Display ``column_images[j][i]`` in row ``i`` of column ``j``.

        ``column_images`` must have one sequence per column. Canvas rows are
        created if more are needed. Cells with no data in a column (shorter
        columns, or rows left over from a previous call) show a zero array
        shaped like that column's first entry, so every column that is
        padded this way must contain at least one entry.
        """
        assert len(column_images) == self.n_columns
        n_rows = max(len(images) for images in column_images)
        if n_rows > self.n_rows:
            self.create_canvas_images(n_rows)

        del self.images[:]
        del self.photos[:]
        for i in range(self.n_rows):
            images = []
            photos = []
            for j in range(self.n_columns):
                width, height, image_function = self.columns[j]
                f = (column_images[j][i] if i < len(column_images[j]) else
                     np.zeros_like(column_images[j][0]))
                image = image_function(f)
                photo = PIL.ImageTk.PhotoImage(image.resize((width, height)))
                self.canvas.itemconfig(self.canvas_images[i][j], image=photo)
                images.append(image)
                photos.append(photo)

            # Keep references so the displayed PhotoImages stay alive.
            self.images.append(images)
            self.photos.append(photos)
# NOTE(review): ScrollWindow is not imported in the visible part of this file;
# confirm it is brought into scope (e.g. from the package's gui module).
class Viewer(tk.Tk):
    """Top-level window for stepping through input images and their data.

    An ``ImageSelector`` at the top shows the current element of ``images``
    and steps through them with its arrow buttons. Below it, a horizontally
    scrolling set of panes (one per ``add_*`` call) shows 'filters' data,
    which is the same for every image, and 'acts' data, which is indexed by
    the current image index.

    Parameters
    ----------
    images : sequence
        Items to step through; ``images[i]`` is passed to ``image_function``.
    image_function : callable
        Turns one element of ``images`` into a PIL.Image for the selector.
    """

    def __init__(self, images, image_function, *args, **kwargs):
        tk.Tk.__init__(self, *args, **kwargs)

        self.filter_size = (50, 50)  # width, height
        self.act_size = (150, 150)  # width, height

        # --- data
        self.images = images
        # Index of the current image; initialised so the arrow buttons work
        # even before set_index has been called explicitly.
        self.i = 0
        self.frame_kind_data = {}

        # --- components
        self.selector = ImageSelector(self)
        self.selector.image_function = image_function
        self.selector.set_left_callback(self._on_prev_image)
        self.selector.set_right_callback(self._on_next_image)

        self.pane_frame = ScrollWindow(self, horizontal=True)
        self.panes = tk.PanedWindow(self.pane_frame, orient='horizontal')
        self.pane_frame.set_window(self.panes)
        self.frames = []

        # --- layout
        self.selector.pack(side='top', expand=False)
        self.pane_frame.pack(fill='both', expand=True)

    @property
    def n_frames(self):
        """Number of panes added so far."""
        return len(self.frames)

    def _kind_size(self, kind):
        """Return the (width, height) used to display images of ``kind``."""
        if kind == 'acts':
            return self.act_size
        elif kind == 'filters':
            return self.filter_size
        raise ValueError("Unrecognized kind %r" % kind)

    def add_column(self, kinds, data, fns, title=None):
        """Add a pane containing one image column per entry of ``kinds``.

        Parameters
        ----------
        kinds : sequence of str
            Each is 'acts' (``data`` entry indexed by image index) or
            'filters' (``data`` entry shown unchanged for every image).
        data : sequence
            One data entry per kind.
        fns : sequence of callable
            One array-to-PIL.Image function per kind.
        title : str, optional
            Text of a label placed above the pane.

        Raises
        ------
        ValueError
            If any kind is not 'acts' or 'filters'.
        """
        assert len(kinds) == len(data) == len(fns)
        # Resolve (and validate) all sizes before creating any widgets, so an
        # invalid kind does not leave a half-built pane behind.
        sizes = [self._kind_size(kind) for kind in kinds]

        pane = tk.Frame(self.panes)
        if title is not None:
            label = tk.Label(pane, text=title)
            label.pack(side='top', fill='both')

        frame = VerticalImageFrame(pane)
        for (width, height), fn in zip(sizes, fns):
            frame.add_column(width, height, fn)

        frame.pack(fill='both', expand=True)
        self.frames.append(frame)
        self.frame_kind_data[frame] = tuple(zip(kinds, data))
        self.panes.add(pane)

    def add_acts(self, acts, acts_fn, **kwargs):
        """Add a pane with a single 'acts' column."""
        self.add_column(('acts',), (acts,), (acts_fn,), **kwargs)

    def add_filters(self, filters, filters_fn, **kwargs):
        """Add a pane with a single 'filters' column."""
        self.add_column(('filters',), (filters,), (filters_fn,), **kwargs)

    def add_filters_acts(self, filters, filters_fn, acts, acts_fn, **kwargs):
        """Add a pane with a 'filters' column followed by an 'acts' column."""
        self.add_column(
            ('filters', 'acts'), (filters, acts), (filters_fn, acts_fn),
            **kwargs)

    def set_index(self, i):
        """Select image ``i`` (wrapped modulo ``len(images)``) and redraw."""
        self.i = i % len(self.images)
        self.selector.set_image(self.images[self.i])
        for frame in self.frames:
            # 'filters' data does not depend on the image; every other kind
            # is indexed by the current image index.
            images = [d if k == 'filters' else d[self.i]
                      for k, d in self.frame_kind_data[frame]]
            frame.set_images(images)

    def _on_prev_image(self):
        """Left-arrow callback: show the previous image (wrapping)."""
        self.set_index(self.i - 1)

    def _on_next_image(self):
        """Right-arrow callback: show the next image (wrapping)."""
        self.set_index(self.i + 1)