API reference

Modules

Modules for adding spiking behaviour to PyTorch models.

pytorch_spiking.SpikingActivation

Module for converting an arbitrary activation function to a spiking equivalent.

pytorch_spiking.Lowpass

Module implementing a Lowpass filter.

pytorch_spiking.TemporalAvgPool

Module for taking the average across one dimension of a tensor.

class pytorch_spiking.SpikingActivation(activation, dt=0.001, initial_state=None, spiking_aware_training=True, return_sequences=True)

Module for converting an arbitrary activation function to a spiking equivalent.

Neurons will spike at a rate proportional to the output of the base activation function. For example, if the activation function outputs a value of 10, then this module will output spikes at a rate of 10 Hz (i.e., 10 spikes per simulated second, where one simulated second is equivalent to 1/dt timesteps). Each spike has height 1/dt, so that the integral of the spiking output equals the integral of the base activation output. If the base activation outputs a negative value, the spikes will have height -1/dt. Multiple spikes per timestep are also possible, in which case the output will be n/dt (where n is the number of spikes).
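The rate-to-spike conversion described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the library's implementation: it assumes an integrate-and-fire style accumulator in which each neuron's voltage integrates `activation(inputs) * dt` and emits one spike (of height 1/dt) for every whole unit it accumulates. The function name `spiking_activation_sketch` and its `rng` argument are hypothetical.

```python
import numpy as np


def spiking_activation_sketch(inputs, activation, dt=0.001, initial_state=None,
                              return_sequences=True, rng=None):
    """Hypothetical sketch: turn activation outputs into spikes of height n/dt.

    inputs has shape (batch_size, n_steps, n_neurons).
    """
    rates = activation(inputs)
    batch_size, n_steps, n_neurons = rates.shape
    if initial_state is None:
        # Assumed default: starting voltages drawn uniformly from [0, 1).
        rng = np.random.default_rng() if rng is None else rng
        voltage = rng.uniform(size=(batch_size, n_neurons))
    else:
        voltage = np.array(initial_state, dtype=float)

    spikes = np.zeros(rates.shape)
    for t in range(n_steps):
        # Integrate the rate over one timestep.
        voltage = voltage + rates[:, t] * dt
        # Every whole unit of voltage is one spike; negative voltages
        # give negative spike counts.
        n_spikes = np.floor(voltage)
        voltage = voltage - n_spikes
        spikes[:, t] = n_spikes / dt

    return spikes if return_sequences else spikes[:, -1]


def relu(x):
    return np.maximum(x, 0.0)


dt = 0.001
x = np.full((1, 1000, 1), 10.0)  # constant input of 10 for 1 simulated second
out = spiking_activation_sketch(x, relu, dt=dt, initial_state=np.full((1, 1), 0.5))
print(round(out.sum() * dt))  # number of spikes emitted → 10
print(out.mean())             # time-averaged output → 10.0
```

Because each spike has height 1/dt, summing the output and multiplying by dt counts the spikes, and averaging over time recovers the 10 Hz rate.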

When applying this layer to an input, make sure that the input has a time axis. The spiking output is computed along that axis, so the number of simulation timesteps is determined by its length. The number of timesteps does not need to be the same during training, evaluation, and inference. In particular, it may be more efficient to use one timestep during training and multiple timesteps during inference (often with spiking_aware_training=False, and apply_during_training=False on any Lowpass layers).
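Static inputs (e.g., images) have no time axis of their own, so one common approach is to repeat each example along a new time axis before passing it to a spiking model. The NumPy sketch below assumes the (batch_size, n_steps, n_features) layout expected by forward; the helper `add_time_axis` is hypothetical, and repeating the input unchanged at every step is only one possible choice.

```python
import numpy as np


def add_time_axis(x, n_steps):
    """Hypothetical helper: (batch_size, n_features) -> (batch_size, n_steps, n_features)."""
    # Insert a length-1 time axis, then repeat the input along it.
    return np.repeat(x[:, np.newaxis, :], n_steps, axis=1)


rng = np.random.default_rng(0)
images = rng.uniform(size=(32, 784))  # e.g. a batch of flattened 28x28 images

train_batch = add_time_axis(images, 1)    # one timestep for cheaper training
test_batch = add_time_axis(images, 100)   # 100 timesteps (0.1 s at dt=0.001)
print(train_batch.shape, test_batch.shape)  # (32, 1, 784) (32, 100, 784)
```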

Parameters
activation : callable

Activation function to be converted to spiking equivalent.

dt : float

Length of time (in seconds) represented by one time step.

initial_state : torch.Tensor

Initial spiking voltage state, as an array with shape (batch_size, n_neurons) and values between 0 and 1. If none is specified, the initial state is drawn from a uniform distribution.

spiking_aware_training : bool

If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.

return_sequences : bool

Whether to return the full sequence of output spikes (default), or just the spikes on the last timestep.
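The interaction between spiking_aware_training and the forward and backward passes can be sketched with a ReLU base activation. This NumPy class is hypothetical and only models the behaviour described above: the forward pass spikes unless the layer is training with spiking_aware_training=False, while the backward pass always uses the derivative of the base activation (here, the ReLU step function), since the spike count itself is piecewise constant and has no useful gradient. The zero initial voltage is a simplification.

```python
import numpy as np


class SpikingReLUSketch:
    """Hypothetical model of `spiking_aware_training` with a ReLU base activation."""

    def __init__(self, dt=0.001, spiking_aware_training=True):
        self.dt = dt
        self.spiking_aware_training = spiking_aware_training
        self.inputs = None

    def forward(self, inputs, training=True):
        # inputs: (batch_size, n_steps, n_neurons)
        self.inputs = inputs  # saved for the backward pass
        rates = np.maximum(inputs, 0.0)  # base activation (ReLU)
        if training and not self.spiking_aware_training:
            return rates  # non-spiking forward pass while training

        voltage = np.zeros_like(rates[:, 0])  # zero initial voltage (simplification)
        spikes = np.zeros_like(rates)
        for t in range(rates.shape[1]):
            voltage = voltage + rates[:, t] * self.dt
            n_spikes = np.floor(voltage)
            voltage = voltage - n_spikes
            spikes[:, t] = n_spikes / self.dt
        return spikes

    def backward(self, grad_output):
        # Always the ReLU derivative, even when the forward pass emitted spikes.
        return grad_output * (self.inputs > 0)


x = np.tile(np.array([[[500.0, -3.0]]]), (1, 10, 1))  # shape (1, 10, 2)

layer = SpikingReLUSketch(dt=0.001, spiking_aware_training=False)
train_out = layer.forward(x, training=True)   # plain ReLU values: 500 and 0
infer_out = layer.forward(x, training=False)  # spikes of height n/dt
grad_in = layer.backward(np.ones_like(x))     # 1 where x > 0, else 0
```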

forward(inputs)

Compute output spikes given inputs.

Parameters
inputs : torch.Tensor

Array of input values with shape (batch_size, n_steps, n_neurons).

Returns
outputs : torch.Tensor

Array of output spikes with shape (batch_size, n_neurons) if return_sequences=False else (batch_size, n_steps, n_neurons). Each element will have value n/dt, where n is the number of spikes emitted by that neuron on that time step.

class pytorch_spiking.Lowpass(tau, units, dt=0.001, apply_during_training=True, initial_level=None, return_sequences=True)

Module implementing a Lowpass filter.

The initial filter state and filter time constants are both trainable parameters. However, if apply_during_training=False then the parameters are not part of the training loop, and so will never be updated.

When applying this layer to an input, make sure that the input has a time axis.
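A minimal NumPy sketch of the filtering, assuming the common exponential-smoothing discretization in which the state keeps a fraction exp(-dt/tau) of its previous value at each timestep and moves the remaining fraction towards the current input. In this sketch tau and the initial level are fixed numbers rather than trained parameters, and returning the input unchanged when the filter is skipped is an assumption about what "ignored during training" means. The function name is hypothetical.

```python
import numpy as np


def lowpass_sketch(inputs, tau, dt=0.001, initial_level=None,
                   apply_during_training=True, training=False,
                   return_sequences=True):
    """Hypothetical sketch: filter inputs of shape (batch_size, n_steps, units)."""
    if training and not apply_during_training:
        # Assumed behaviour: the filter is bypassed during training.
        return inputs if return_sequences else inputs[:, -1]

    smoothing = np.exp(-dt / tau)  # per-step decay of the filter state
    if initial_level is None:
        level = np.zeros(inputs[:, 0].shape)
    else:
        level = np.array(initial_level, dtype=float)

    outputs = np.zeros(inputs.shape)
    for t in range(inputs.shape[1]):
        # Keep a `smoothing` fraction of the old state, mix in the new input.
        level = smoothing * level + (1 - smoothing) * inputs[:, t]
        outputs[:, t] = level
    return outputs if return_sequences else outputs[:, -1]


# Step response: with tau equal to 10 timesteps, the output reaches about
# 1 - exp(-1) (roughly 0.632) of the input after 10 steps.
step = np.ones((1, 50, 1))
filtered = lowpass_sketch(step, tau=0.01, dt=0.001)
print(filtered[0, 9, 0])
```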

Parameters
tau : float

Time constant of filter (in seconds).

units : int

Number of units, i.e., the size of the last (feature) dimension of the input.

dt : float

Length of time (in seconds) represented by one time step.

apply_during_training : bool

If False, this layer will effectively be ignored during training. This often makes sense in concert with the swappable training behaviour of, e.g., SpikingActivation: if the activations are not spiking during training, then they often don’t need to be filtered either.

initial_level : torch.Tensor

Initial value of the filter state.

return_sequences : bool

Whether to return the full sequence of filtered output (default), or just the output on the last timestep.

forward(inputs)

Apply filter to inputs.

Parameters
inputs : torch.Tensor

Array of input values with shape (batch_size, n_steps, units).

Returns
outputs : torch.Tensor

Array of filtered output values with shape (batch_size, units) if return_sequences=False, else (batch_size, n_steps, units).

class pytorch_spiking.TemporalAvgPool(dim=1)

Module for taking the average across one dimension of a tensor.

Parameters
dim : int, optional

The dimension to average across. Defaults to the second dimension (dim=1), which is typically the time dimension (for tensors that have a time dimension).

forward(inputs)

Apply average pooling to inputs.

Parameters
inputs : torch.Tensor

Array of input values with shape (batch_size, n_steps, ...).

Returns
outputs : torch.Tensor

Array of output values with shape (batch_size, ...). The averaged dimension (by default the time dimension) is removed.
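Averaging over the time dimension is one way to read a rate back out of a spike train. The NumPy sketch below assumes TemporalAvgPool behaves like a plain mean over `dim` (the NumPy analogue of `torch.mean(inputs, dim=dim)`); the function name is hypothetical, and the spike train is synthetic, with one spike of height 1/dt every 100 steps, i.e. 10 Hz at dt=0.001.

```python
import numpy as np


def temporal_avg_pool_sketch(inputs, dim=1):
    """Hypothetical sketch: average over `dim` and drop that dimension."""
    return inputs.mean(axis=dim)


dt = 0.001
n_steps = 1000  # 1 simulated second
spikes = np.zeros((2, n_steps, 3))  # (batch_size, n_steps, n_neurons)
spikes[:, ::100, :] = 1 / dt        # one spike every 100 steps -> 10 Hz

rates = temporal_avg_pool_sketch(spikes)
print(rates.shape)  # (2, 3)
print(rates)        # every entry is 10.0
```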

Functions

Functional implementation of spiking layers.

pytorch_spiking.functional.SpikingActivation

Function for converting an arbitrary activation function to a spiking equivalent.

class pytorch_spiking.functional.SpikingActivation(*args, **kwargs)

Function for converting an arbitrary activation function to a spiking equivalent.

Notes

We do not recommend calling this directly; use pytorch_spiking.SpikingActivation instead.

static forward(ctx, inputs, activation, dt=0.001, initial_state=None, spiking_aware_training=True, return_sequences=False, training=False)

Forward pass of SpikingActivation function.

Parameters
inputs : torch.Tensor

Array of input values with shape (batch_size, n_steps, n_neurons).

activation : callable

Activation function to be converted to spiking equivalent.

dt : float

Length of time (in seconds) represented by one time step.

initial_state : torch.Tensor

Initial spiking voltage state, as an array with shape (batch_size, n_neurons) and values between 0 and 1. If none is specified, the initial state is drawn from a uniform distribution.

spiking_aware_training : bool

If True (default), use the spiking activation function for the forward pass and the base activation function for the backward pass. If False, use the base activation function for the forward and backward pass during training.

return_sequences : bool

If True, return the full output sequence; if False (default), return only the output on the last timestep.

training : bool

Whether this function should be executed in training or evaluation mode (this only matters if spiking_aware_training=False).

static backward(ctx, grad_output)

Backward pass of SpikingActivation function.