"""
Objects to be used with the Keras callback functionality.
See https://www.tensorflow.org/guide/keras/custom_callback for more information
on how to use Keras callbacks.
The short answer is that these can be passed to, e.g., `.Simulator.fit` like
.. code-block:: python
sim.fit(..., callbacks=[nengo_dl.callbacks.NengoSummaries(...)]
"""
import contextlib

import nengo
from nengo.exceptions import ValidationError
import tensorflow as tf
from tensorflow.python.eager import context

from nengo_dl import compat, utils


class NengoSummaries(tf.keras.callbacks.Callback):
"""
Logs the values of Nengo object parameters, to be displayed in TensorBoard.
See https://www.tensorflow.org/tensorboard/get_started for general instructions
on using TensorBoard.
Parameters
----------
log_dir : str
Directory where log file will be written.
sim : `.Simulator`
Simulator object which will be used to look up parameter values.
objects : list of `nengo.Ensemble` or `nengo.ensemble.Neurons` or `nengo.Connection`
The object whose parameter values we want to record (passing an Ensemble will
log its encoders, Neurons will log biases, and Connection will log connection
weights/decoders).
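
    Examples
    --------
    A minimal usage sketch (the model attributes ``net.ens`` and ``net.conn``, and the
    ``inputs``/``targets`` arrays, are illustrative placeholders, not part of this API):

    .. code-block:: python

        with nengo_dl.Simulator(net) as sim:
            summaries = nengo_dl.callbacks.NengoSummaries(
                log_dir="logs/nengo",
                sim=sim,
                objects=[net.ens, net.ens.neurons, net.conn],
            )
            sim.compile(optimizer=tf.optimizers.SGD(1e-3), loss="mse")
            sim.fit(inputs, targets, epochs=5, callbacks=[summaries])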
"""

    def __init__(self, log_dir, sim, objects):
super().__init__()
self.sim = sim
        # the summary writer is created (and later used) in eager mode, so
        # temporarily enable eager execution if it has been disabled
        with contextlib.suppress() if compat.eager_enabled() else context.eager_mode():
            self.writer = tf.summary.create_file_writer(log_dir)
self.summaries = []
for obj in objects:
if isinstance(
obj, (nengo.Ensemble, nengo.ensemble.Neurons, nengo.Connection)
):
if isinstance(obj, nengo.Ensemble):
param = "encoders"
name = "Ensemble_%s" % obj.label
elif isinstance(obj, nengo.ensemble.Neurons):
param = "bias"
name = "Ensemble.neurons_%s" % obj.ensemble.label
elif isinstance(obj, nengo.Connection):
if not compat.conn_has_weights(obj):
raise ValidationError(
"Connection '%s' does not have any weights to log" % obj,
"objects",
)
param = "weights"
name = "Connection_%s" % obj.label
self.summaries.append(
(utils.sanitize_name("%s_%s" % (name, param)), obj, param)
)
else:
raise ValidationError(
"Unknown summary object %s; should be an Ensemble, Neurons, or "
"Connection" % obj,
"objects",
)

    def on_epoch_end(self, epoch, logs=None):
"""Log parameter values at the end of each epoch."""
        # look up the current value of each logged parameter
        summary_vals = self.sim.data.get_params(
            *[(obj, attr) for _, obj, attr in self.summaries]
        )
with (
contextlib.suppress() if compat.eager_enabled() else context.eager_mode()
), self.writer.as_default():
for (name, _, _), val in zip(self.summaries, summary_vals):
tf.summary.histogram(name, val, step=epoch)

    def on_train_end(self, logs=None):
"""Close summary writer at end of training."""
with contextlib.suppress() if compat.eager_enabled() else context.eager_mode():
self.writer.close()


class TensorBoard(tf.keras.callbacks.TensorBoard):
"""
A version of the Keras TensorBoard callback that also profiles inference.
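
    Examples
    --------
    A minimal sketch, using the standard Keras TensorBoard arguments (the
    ``log_dir`` value and the ``inputs`` data are illustrative):

    .. code-block:: python

        callback = nengo_dl.callbacks.TensorBoard(
            log_dir="logs/profile", profile_batch=1
        )
        sim.predict(inputs, callbacks=[callback])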
"""

    def on_predict_batch_end(self, *args, **kwargs):
"""Redirect to training function."""
self.on_batch_end(*args, **kwargs)

    def on_predict_begin(self, *args, **kwargs):
"""Redirect to training function."""
self.on_train_begin(*args, **kwargs)

    def on_predict_end(self, *args, **kwargs):
"""Redirect to training function."""
self.on_train_end(*args, **kwargs)


class IsolateState(tf.keras.callbacks.Callback):
"""
    Isolate the internal state of the simulation from any other stateful operations.

    This will cause every batch to begin from the same initial state (the state of
    the simulation whenever this callback is created). When the operation using this
    callback completes, the simulation state will be restored to that initial state.

Parameters
----------
sim : `.Simulator`
The Simulator containing the state we want to control.
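
    Examples
    --------
    A minimal sketch (the input arrays are illustrative); because the state is
    reset after every batch, each of the calls below starts from the same
    initial simulator state:

    .. code-block:: python

        callback = nengo_dl.callbacks.IsolateState(sim)
        sim.predict(first_inputs, callbacks=[callback])
        sim.predict(second_inputs, callbacks=[callback])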
"""

    def __init__(self, sim):
super().__init__()
self.sim = sim
        # take a snapshot of the current simulation state (None if the simulation
        # has not been run yet, in which case a standard reset suffices)
        self.saved_state = (
            None
            if sim.n_steps == 0
            else tf.keras.backend.batch_get_value(
                list(sim.tensor_graph.saved_state.values())
            )
        )

    def reset(self):
        """Reset the simulation state to the saved state."""
if self.saved_state is None:
self.sim.reset(
include_probes=False, include_trainable=False, include_processes=False
)
else:
tf.keras.backend.batch_set_value(
list(zip(self.sim.tensor_graph.saved_state.values(), self.saved_state))
)
self.sim._update_steps()

    def on_train_batch_end(self, batch, logs=None):
"""Reset state at the end of each batch."""
self.reset()

    def on_predict_batch_end(self, batch, logs=None):
"""Reset state at the end of each batch."""
self.reset()

    def on_test_batch_end(self, batch, logs=None):
"""Reset state at the end of each batch."""
self.reset()