# Source code for nengo_extras.deepview

"""
Tools for visualizing deep neural networks
"""
import collections

import numpy as np
import PIL.ImageTk

from .compat import tkinter as tk


class ImageSelector(tk.Frame):
    """Choose between images with left and right arrows

    Attributes
    ----------
    image_function : callable
        Turn an array into a PIL.Image. Can be generated by
        ``nengo_extras.gui.image_function``.
    resample : int
        Resampling mode for ``PIL.Image.resize``.
    image : PIL.Image or None
        The currently selected image (before resizing to the canvas), or
        None until ``set_image`` has been called.
    """

    def __init__(self, parent, **kwargs):
        tk.Frame.__init__(self, parent, **kwargs)
        self.image_function = None
        self.resample = 0
        # Nothing to draw until set_image() is called; the canvas <Configure>
        # handler below can fire before that happens.
        self.image = None
        self.photo = None  # PhotoImage currently shown on the canvas

        self.canvas = tk.Canvas(self)
        self.canvas_image = self.canvas.create_image(0, 0, anchor='nw')
        self.left = tk.Button(self, text='<')
        self.right = tk.Button(self, text='>')
        self.canvas.bind('<Configure>', self._on_canvas_resize)

        # --- layout
        self.left.pack(side='left', expand=False)
        self.right.pack(side='right', expand=False)
        self.canvas.pack(fill='both', expand=True)

    def set_left_callback(self, callback):
        """Use ``callback`` as the command of the left ('<') button."""
        self.left.configure(command=callback)

    def set_right_callback(self, callback):
        """Use ``callback`` as the command of the right ('>') button."""
        self.right.configure(command=callback)

    def _canvas_resized(self):
        """Redraw the current image scaled to the canvas's current size."""
        if self.image is None:
            # Canvas was configured before any image was selected.
            return
        # Guard against a degenerate (< 1 px) reported size so the resize
        # target is always positive.
        h = max(self.canvas.winfo_height(), 1)
        w = max(self.canvas.winfo_width(), 1)
        image = self.image.resize((w, h), self.resample)
        photo = PIL.ImageTk.PhotoImage(image)
        self.photo = photo  # hold a reference while the canvas displays it
        self.canvas.itemconfig(self.canvas_image, image=photo)

    def _on_canvas_resize(self, event):
        self._canvas_resized()

    def set_image(self, image_array):
        """Convert ``image_array`` with ``image_function`` and display it."""
        self.image = self.image_function(image_array)
        self._canvas_resized()
class ScrollCanvasFrame(tk.Frame):
    """Frame wrapping a canvas with optional vertical/horizontal scrollbars.

    Parameters
    ----------
    parent : widget
        Parent widget of this frame.
    vertical : bool
        If True, add a vertical scrollbar (``self.vscrollbar``) on the right.
    horizontal : bool
        If True, add a horizontal scrollbar (``self.hscrollbar``) at the
        bottom.
    **kwargs
        Passed through to ``tk.Frame``.
    """

    def __init__(self, parent, vertical=False, horizontal=False, **kwargs):
        tk.Frame.__init__(self, parent, **kwargs)
        self.vertical = vertical
        self.horizontal = horizontal

        canvas = tk.Canvas(self)
        self.canvas = canvas

        if vertical:
            vbar = tk.Scrollbar(self, orient='vertical', command=canvas.yview)
            self.vscrollbar = vbar
            canvas.configure(yscrollcommand=vbar.set)
            vbar.pack(side='right', fill='y')

        if horizontal:
            hbar = tk.Scrollbar(
                self, orient='horizontal', command=canvas.xview)
            self.hscrollbar = hbar
            canvas.configure(xscrollcommand=hbar.set)
            hbar.pack(side='bottom', fill='both')

        # Packed after the scrollbars so they keep their edge space.
        canvas.pack(fill='both', expand=True)
        self.bind('<Configure>', self._on_resize)

    def _on_resize(self, event):
        """Set the canvas scroll region to the bounding box of all items."""
        region = self.canvas.bbox('all')
        self.canvas.configure(scrollregion=region)
class ScrollWindow(ScrollCanvasFrame):
    """Scroll a single window embedded in the canvas.

    Along every axis that has no scrollbar, the embedded window is resized
    to match the canvas whenever the canvas is resized.
    """

    def __init__(self, *args, **kwargs):
        ScrollCanvasFrame.__init__(self, *args, **kwargs)
        self.item = self.canvas.create_window((0, 0), anchor='nw')
        self.canvas.bind('<Configure>', self._on_canvas_resize)

    def set_window(self, window):
        """Embed ``window`` as the canvas's scrolled content."""
        self.canvas.itemconfig(self.item, window=window)

    def _on_canvas_resize(self, event):
        # Stretch the window to the canvas along each non-scrolling axis.
        # (Previously the vertical-only case referenced a nonexistent
        # ``self.pane_item`` and raised AttributeError.)
        size = {}
        if not self.vertical:
            size['height'] = self.canvas.winfo_height()
        if not self.horizontal:
            size['width'] = self.canvas.winfo_width()
        if size:
            self.canvas.itemconfig(self.item, **size)
class VerticalImageFrame(ScrollCanvasFrame):
    """Vertically scrolling grid of images arranged in columns.

    Columns are added with ``add_column``; each has its own display size and
    ``image_function`` (array -> ``PIL.Image``). Rows are stacked top to
    bottom, spaced by the tallest column height plus ``padding[1]``.
    """
    Column = collections.namedtuple(
        'Column', ('width', 'height', 'image_function'))

    def __init__(self, parent, **kwargs):
        ScrollCanvasFrame.__init__(self, parent, vertical=True, **kwargs)
        self.padding = (5, 5)  # width, height
        self.columns = []
        self.canvas_images = []  # canvas item ids, by row, then column
        self.images = []  # PIL images, by row, then column
        # PhotoImages by row, then column; presumably held here to keep them
        # referenced while the canvas displays them.
        self.photos = []

    @property
    def n_columns(self):
        return len(self.columns)

    @property
    def n_rows(self):
        return len(self.canvas_images)

    @property
    def column_heights(self):
        return (column.height for column in self.columns)

    @property
    def column_widths(self):
        return (column.width for column in self.columns)

    @property
    def column_image_functions(self):
        return (column.image_function for column in self.columns)

    def add_column(self, width, height, image_function):
        """Add a column whose images are displayed at ``width`` x ``height``."""
        self.columns.append(self.Column(width, height, image_function))
        self.canvas.configure(width=sum(self.column_widths))

    def create_canvas_images(self, n):
        """Delete existing canvas items and create an empty ``n``-row grid."""
        for image_row in self.canvas_images:
            for image in image_row:
                self.canvas.delete(image)
        del self.canvas_images[:]

        # ``[0] +`` keeps max() defined when no columns have been added yet.
        max_height = max([0] + list(self.column_heights))
        widths = np.array(list(self.column_widths))
        # Left edge of each column: cumulative (width + horizontal padding).
        pos_x = np.cumsum([0] + list(widths + self.padding[0]))[:-1]
        row_step = max_height + self.padding[1]
        for i in range(n):
            image_row = [
                self.canvas.create_image(pos_x[j], i * row_step, anchor='nw')
                for j in range(self.n_columns)]
            self.canvas_images.append(image_row)

    def set_images(self, column_images):
        """Render one sequence of arrays per column into the grid.

        Parameters
        ----------
        column_images : sequence of sequences of arrays
            ``column_images[j][i]`` is shown in row ``i`` of column ``j``;
            there must be exactly one entry per column. A column with fewer
            arrays than the grid has rows is padded with zero arrays shaped
            like its first array, so each column needs at least one array
            whenever the grid has rows.
        """
        assert len(column_images) == self.n_columns
        n_rows = max([0] + [len(images) for images in column_images])
        if n_rows > self.n_rows:
            self.create_canvas_images(n_rows)

        del self.images[:]
        del self.photos[:]
        for i in range(self.n_rows):
            images = []
            photos = []
            for j, (width, height, image_function) in enumerate(self.columns):
                arrays = column_images[j]
                f = (arrays[i] if i < len(arrays)
                     else np.zeros_like(arrays[0]))
                image = image_function(f)
                photo = PIL.ImageTk.PhotoImage(image.resize((width, height)))
                self.canvas.itemconfig(self.canvas_images[i][j], image=photo)
                images.append(image)
                photos.append(photo)
            self.images.append(images)
            self.photos.append(photos)

        # The base frame only recomputes the scroll region when the frame
        # itself is resized; refresh it now that the grid contents changed.
        self.canvas.configure(scrollregion=self.canvas.bbox('all'))
class Viewer(tk.Tk):
    """Top-level window for stepping through inputs and viewing layer data.

    An ``ImageSelector`` at the top shows the current input image. Below it,
    a horizontally scrolling row of panes each holds a ``VerticalImageFrame``
    of filter and/or activation columns.

    Parameters
    ----------
    images : sequence
        Input arrays to step through with the selector's arrow buttons.
    image_function : callable
        Turns an input array into a ``PIL.Image`` for the selector.
    """

    def __init__(self, images, image_function, *args, **kwargs):
        tk.Tk.__init__(self, *args, **kwargs)
        self.filter_size = (50, 50)  # width, height
        self.act_size = (150, 150)  # width, height

        # --- data
        self.images = images
        # Current image index. Initialized here so the arrow buttons work
        # even if set_index() has not been called yet.
        self.i = 0
        self.frame_kind_data = {}

        # --- components
        self.selector = ImageSelector(self)
        self.selector.image_function = image_function
        self.selector.set_left_callback(self._on_prev_image)
        self.selector.set_right_callback(self._on_next_image)

        self.pane_frame = ScrollWindow(self, horizontal=True)
        self.panes = tk.PanedWindow(self.pane_frame, orient='horizontal')
        self.pane_frame.set_window(self.panes)
        self.frames = []

        # --- layout
        self.selector.pack(side='top', expand=False)
        self.pane_frame.pack(fill='both', expand=True)

    @property
    def n_frames(self):
        return len(self.frames)

    def add_column(self, kinds, data, fns, title=None):
        """Add a pane with one image column per ``(kind, data, fn)`` triple.

        Parameters
        ----------
        kinds : sequence of str
            Each either ``'filters'`` (data shown unchanged for every input)
            or ``'acts'`` (data indexed by the current input index).
        data : sequence
            Data for each column, matched positionally with ``kinds``.
        fns : sequence of callable
            Image function (array -> ``PIL.Image``) for each column.
        title : str, optional
            Text of a label shown above the pane.

        Raises
        ------
        ValueError
            If an entry of ``kinds`` is neither ``'acts'`` nor ``'filters'``.
        """
        assert len(kinds) == len(data) == len(fns)
        sizes = {'acts': self.act_size, 'filters': self.filter_size}
        # Validate before creating widgets so a bad kind leaves no
        # half-built pane behind.
        for kind in kinds:
            if kind not in sizes:
                raise ValueError("Unrecognized kind %r" % kind)

        pane = tk.Frame(self.panes)
        if title is not None:
            label = tk.Label(pane, text=title)
            label.pack(side='top', fill='both')

        frame = VerticalImageFrame(pane)
        for kind, fn in zip(kinds, fns):
            width, height = sizes[kind]
            frame.add_column(width, height, fn)
        frame.pack(fill='both', expand=True)

        self.frames.append(frame)
        self.frame_kind_data[frame] = tuple(zip(kinds, data))
        self.panes.add(pane)

    def add_acts(self, acts, acts_fn, **kwargs):
        self.add_column(('acts',), (acts,), (acts_fn,), **kwargs)

    def add_filters(self, filters, filters_fn, **kwargs):
        self.add_column(('filters',), (filters,), (filters_fn,), **kwargs)

    def add_filters_acts(self, filters, filters_fn, acts, acts_fn, **kwargs):
        self.add_column(
            ('filters', 'acts'), (filters, acts), (filters_fn, acts_fn),
            **kwargs)

    def set_index(self, i):
        """Show input ``i`` (wrapped modulo ``len(self.images)``) and the
        matching data in every frame."""
        self.i = i % len(self.images)
        self.selector.set_image(self.images[self.i])
        for frame in self.frames:
            # 'filters' data is the same for every input; every other kind
            # is indexed by the current input.
            images = [d if k == 'filters' else d[self.i]
                      for k, d in self.frame_kind_data[frame]]
            frame.set_images(images)

    def _on_prev_image(self):
        self.set_index(self.i - 1)

    def _on_next_image(self):
        self.set_index(self.i + 1)