API reference
Modules
Modules for adding spiking behaviour to PyTorch models.
SpikingActivation: Module for converting an arbitrary activation function to a spiking equivalent.
Lowpass: Module implementing a lowpass filter.
Module for taking the average across one dimension of a tensor.
class pytorch_spiking.SpikingActivation(activation, dt=0.001, initial_state=None, spiking_aware_training=True, return_sequences=True)
Module for converting an arbitrary activation function to a spiking equivalent.
Neurons will spike at a rate proportional to the output of the base activation function. For example, if the activation function is outputting a value of 10, then the SpikingActivation module will output spikes at a rate of 10 Hz (i.e., 10 spikes per 1 simulated second, where 1 simulated second is equivalent to 1/dt time steps). Each spike will have height 1/dt, so that the integral of the spiking output will be the same as the integral of the base activation output. Note that if the base activation is outputting a negative value, then the spikes will have height -1/dt. Multiple spikes per timestep are also possible, in which case the output will be n/dt (where n is the number of spikes).
When applying this layer to an input, make sure that the input has a time axis (the n_steps dimension of the (batch_size, n_steps, n_neurons) input expected by forward). The spiking output will be computed along the time axis, and the number of simulation timesteps will depend on the length of that axis. The number of timesteps does not need to be the same during training, evaluation, and inference. In particular, it may be more efficient to use one timestep during training and multiple timesteps during inference (often with spiking_aware_training=False, and apply_during_training=False on any Lowpass layers).
Parameters
- activation : callable
Activation function to be converted to spiking equivalent.
- dt : float
Length of time (in seconds) represented by one time step.
- initial_state : torch.Tensor
Initial spiking voltage state (should be an array with shape (batch_size, n_neurons), with values between 0 and 1). Will use a uniform distribution if none is specified.
- spiking_aware_training : bool
If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.
- return_sequences : bool
Whether to return the full sequence of output spikes (default), or just the spikes on the last timestep.
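To make the rate-coding description above concrete, here is a minimal pure-Python sketch for a single neuron. It assumes an integrate-and-fire style accumulator: add activation * dt to a voltage each step, emit one spike per whole unit crossed, and keep the remainder. The function name spiking_sketch is hypothetical; this illustrates the documented behaviour (rate proportional to activation, spike height 1/dt, n/dt for multiple spikes, -1/dt for negative activations) and is not the library's implementation.

```python
import math
import random


def spiking_sketch(rates, dt=0.001, initial_voltage=None, seed=0):
    """Hypothetical single-neuron sketch of rate-to-spike conversion.

    `rates` holds the base activation's output on each timestep. Each
    returned value is n / dt, where n is the (possibly negative) number of
    spikes emitted on that timestep.
    """
    if initial_voltage is None:
        # Uniformly distributed initial state in [0, 1), as in the default.
        initial_voltage = random.Random(seed).random()
    voltage = initial_voltage
    outputs = []
    for rate in rates:
        voltage += rate * dt            # integrate the activation over one step
        n_spikes = math.floor(voltage)  # one spike per whole unit crossed
        voltage -= n_spikes             # keep the remainder in [0, 1)
        outputs.append(n_spikes / dt)   # each spike has height 1/dt
    return outputs


dt = 0.001
n_steps = round(1 / dt)  # 1 simulated second = 1/dt time steps

# Constant activation of 10 for one simulated second.
out = spiking_sketch([10.0] * n_steps, dt=dt)
print("integral of output:", round(sum(out) * dt))  # total spikes emitted
print("mean output:", sum(out) / n_steps)           # tracks the activation

# Activation of 2500 with dt=0.001 means 2.5 spikes per step on average,
# so individual steps carry 2/dt or 3/dt.
print(spiking_sketch([2500.0] * 4, dt=dt, initial_voltage=0.0))
# [2000.0, 3000.0, 2000.0, 3000.0]

# Negative activation: spikes of height -1/dt.
print(min(spiking_sketch([-10.0] * n_steps, dt=dt)))
```

Because the remainder is carried from step to step, the total number of spikes over a run in this sketch equals the integer part of the initial voltage plus the accumulated activation * dt, which is why the time-averaged output follows the activation value.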
forward(inputs)
Compute output spikes given inputs.
Parameters
- inputs : torch.Tensor
Array of input values with shape (batch_size, n_steps, n_neurons).
Returns
- outputs : torch.Tensor
Array of output spikes with shape (batch_size, n_neurons) if return_sequences=False, else (batch_size, n_steps, n_neurons). Each element will have value n/dt, where n is the number of spikes emitted by that neuron on that time step.
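The shape contract of forward can be illustrated with a batched accumulator written over nested Python lists instead of tensors. As a sketch under stated assumptions: the integrate-and-fire accumulator, the name spiking_forward_sketch, and the relu helper are hypothetical; only the input and output shapes, the optional (batch_size, n_neurons) initial state, and the n/dt element values come from the documentation above.

```python
import math
import random


def spiking_forward_sketch(inputs, activation, dt=0.001, initial_state=None,
                           return_sequences=True, seed=0):
    """Hypothetical list-based sketch of the forward() shape contract.

    inputs: nested lists with shape (batch_size, n_steps, n_neurons).
    initial_state: optional nested lists with shape (batch_size, n_neurons)
    and values in [0, 1); drawn uniformly at random when omitted.
    Returns (batch_size, n_steps, n_neurons) if return_sequences is True,
    else (batch_size, n_neurons).
    """
    batch_size, n_neurons = len(inputs), len(inputs[0][0])
    if initial_state is None:
        rng = random.Random(seed)
        initial_state = [[rng.random() for _ in range(n_neurons)]
                         for _ in range(batch_size)]
    outputs = []
    for b in range(batch_size):
        voltage = list(initial_state[b])  # copy: leave the caller's state intact
        sequence = []
        for step in inputs[b]:  # iterate over the time axis
            row = []
            for j, value in enumerate(step):
                voltage[j] += activation(value) * dt
                n = math.floor(voltage[j])  # spikes emitted on this step
                voltage[j] -= n
                row.append(n / dt)          # element value is n/dt
            sequence.append(row)
        outputs.append(sequence if return_sequences else sequence[-1])
    return outputs


def relu(x):
    return max(x, 0.0)


x = [[[5.0, -3.0, 20.0] for _ in range(4)] for _ in range(2)]  # (2, 4, 3)
full = spiking_forward_sketch(x, relu)
last = spiking_forward_sketch(x, relu, return_sequences=False)
print(len(full), len(full[0]), len(full[0][0]))  # 2 4 3
print(len(last), len(last[0]))                   # 2 3
```

Here relu maps the second neuron's input of -3.0 to 0, so that neuron never spikes; and because both calls use the same seed, the return_sequences=False result equals the last timestep of the full sequence.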
class pytorch_spiking.Lowpass(tau, units, dt=0.001, apply_during_training=True, initial_level=None, return_sequences=True)
Module implementing a lowpass filter.
The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False, then the parameters are not part of the training loop, and so will never be updated.
When applying this layer to an input, make sure that the input has a time axis.
Parameters
- tau : float
Time constant of filter (in seconds).
- dt : float
Length of time (in seconds) represented by one time step.
- apply_during_training : bool
If False, this layer will effectively be ignored during training (this often makes sense in concert with the swappable training behaviour in, e.g., SpikingActivation, since if the activations are not spiking during training then we often don’t need to filter them either).
- initial_level : torch.Tensor
Initial filter state.
- return_sequences : bool
Whether to return the full sequence of filtered output (default), or just the output on the last timestep.
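A lowpass filter smooths its input with time constant tau. The sketch below assumes a zero-order-hold discretization of dx/dt = (u - x) / tau, namely x <- a*x + (1 - a)*u with a = exp(-dt/tau); the documentation above does not state which discretization Lowpass uses, and the name lowpass_sketch and its training argument are hypothetical. It also mimics apply_during_training=False by passing the input through unchanged in training mode.

```python
import math


def lowpass_sketch(inputs, tau, dt=0.001, initial_level=0.0,
                   apply_during_training=True, training=False,
                   return_sequences=True):
    """Hypothetical single-channel lowpass filter over a 1-D time series.

    Assumes a zero-order-hold discretization of dx/dt = (u - x) / tau:
    x <- a * x + (1 - a) * u, with a = exp(-dt / tau).
    """
    if training and not apply_during_training:
        # Mimic apply_during_training=False: the layer is a no-op in training.
        outputs = list(inputs)
    else:
        a = math.exp(-dt / tau)  # per-step decay factor
        level = initial_level
        outputs = []
        for u in inputs:
            level = a * level + (1 - a) * u
            outputs.append(level)
    return outputs if return_sequences else outputs[-1]


# Step input of 1.0 with tau = 10 ms and dt = 1 ms.
out = lowpass_sketch([1.0] * 50, tau=0.01)
print(out[0], out[-1])

# In training mode with apply_during_training=False the input passes through.
print(lowpass_sketch([1.0, 2.0], tau=0.01, apply_during_training=False,
                     training=True))  # [1.0, 2.0]
```

With a constant input of 1 and an initial level of 0, the output after k steps is 1 - a**k, so after five time constants (50 steps with dt=0.001 and tau=0.01) it has reached about 99.3% of the input.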
Functions
Functional implementation of spiking layers.
SpikingActivation: Function for converting an arbitrary activation function to a spiking equivalent.
class pytorch_spiking.functional.SpikingActivation(*args, **kwargs)
Function for converting an arbitrary activation function to a spiking equivalent.
Notes
We do not recommend calling this directly; use pytorch_spiking.SpikingActivation instead.
static forward(ctx, inputs, activation, dt=0.001, initial_state=None, spiking_aware_training=True, return_sequences=False, training=False)
Forward pass of SpikingActivation function.
Parameters
- inputs : torch.Tensor
Array of input values with shape (batch_size, n_steps, n_neurons).
- activation : callable
Activation function to be converted to spiking equivalent.
- dt : float
Length of time (in seconds) represented by one time step.
- initial_state : torch.Tensor
Initial spiking voltage state (should be an array with shape (batch_size, n_neurons), with values between 0 and 1). Will use a uniform distribution if none is specified.
- spiking_aware_training : bool
If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.
- return_sequences : bool
If True, return the full output sequence; if False (default), return only the output on the last timestep.
- training : bool
Whether this function should be executed in training or evaluation mode (this only matters if spiking_aware_training=False).
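The interplay of spiking_aware_training and training described in the parameters of the functional forward can be summarized piecewise. Writing a(x) for the base activation, s(x) for the spiking output, and sg[.] for a stop-gradient operator (notation introduced here for illustration, not part of the API), one way to express the documented behaviour is:

```latex
y =
\begin{cases}
a(x) + \operatorname{sg}\!\left[s(x) - a(x)\right] & \text{spiking\_aware\_training = True} \\
a(x) & \text{spiking\_aware\_training = False, training = True} \\
s(x) & \text{spiking\_aware\_training = False, training = False}
\end{cases}
\qquad
\text{first case: } y = s(x), \quad \frac{\partial y}{\partial x} = a'(x)
```

In the first case the stop-gradient term cancels a(x) in the forward value but contributes no gradient, so the forward pass sees spikes while the backward pass sees the derivative of the base activation; this is a common way to write such a surrogate, not necessarily how the library computes it.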
static