from collections import defaultdict, OrderedDict
import logging
import warnings
from nengo.exceptions import SimulationError, ValidationError
import numpy as np
from nengo_loihi.block import Probe
from nengo_loihi.compat import is_array, is_number, make_process_step
from nengo_loihi.discretize import (
decay_int,
LEARN_FRAC,
learn_overflow_bits,
overflow_signed,
scale_pes_errors,
shift,
Q_BITS,
U_BITS,
)
from nengo_loihi.validate import validate_model
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class EmulatorInterface:
    """Software emulator for Loihi chip behaviour.

    Parameters
    ----------
    model : Model
        Model specification that will be simulated.
    seed : int, optional (Default: None)
        A seed for all stochastic operations done in this simulator.
        When ``None``, one is drawn from NumPy's global random state.
    """

    strict = False

    def __init__(self, model, seed=None):
        # Report as closed until every piece of state has been built.
        self.closed = True
        validate_model(model)

        self.seed = np.random.randint(2 ** 31 - 1) if seed is None else seed
        logger.debug("EmulatorInterface seed: %d", self.seed)
        self.rng = np.random.RandomState(self.seed)

        info = BlockInfo(model.blocks)
        self.block_info = info
        self.inputs = list(model.inputs)
        logger.debug("EmulatorInterface dtype: %s", info.dtype)

        # One state object per aspect of the emulated chip
        self.compartment = CompartmentState(info, strict=self.strict)
        self.synapses = SynapseState(
            info,
            pes_error_scale=getattr(model, "pes_error_scale", 1.0),
            strict=self.strict,
        )
        self.axons = AxonState(info)
        self.probes = ProbeState(info, self.inputs, model.dt)

        self.t = 0
        self._chip2host_sent_steps = 0
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Mark the emulator closed and drop references to its state.

        ``self.probes`` is deliberately kept so that probe data can still
        be read after closing.
        """
        self.closed = True
        for attr in ("block_info", "inputs", "compartment", "synapses", "axons"):
            setattr(self, attr, None)

    def chip2host(self, probes_receivers):
        """Send not-yet-sent probe data to each probe's receiver.

        Parameters
        ----------
        probes_receivers : dict
            Maps each probe to the receiver its data is sent to.
        """
        n_sent = 0
        for probe, receiver in probes_receivers.items():
            n_probe = self.probes.send(probe, self._chip2host_sent_steps, receiver)
            if n_sent == 0:
                n_sent = n_probe
            # every probe that sent anything must have sent the same amount
            assert n_probe == 0 or n_sent == n_probe

        self._chip2host_sent_steps += n_sent

    def host2chip(self, spikes, errors):
        """Queue input spikes and pass PES errors on to the synapses.

        Parameters
        ----------
        spikes : iterable of (spike_input, t, spike_idxs)
            Spikes to add to each spike input at step ``t``.
        errors : iterable
            Forwarded unchanged to ``SynapseState.update_pes_errors``.
        """
        for spike_input, t, spike_idxs in spikes:
            spike_input.add_spikes(t, spike_idxs)

        self.synapses.update_pes_errors(errors)

    def run_steps(self, steps):
        """Simulate for the given number of ``dt`` steps.

        Parameters
        ----------
        steps : int
            Number of steps to run the simulation for.
        """
        for _ in range(steps):
            self.step()

    def step(self):
        """Advance the simulation by 1 step (``dt`` seconds)."""
        self.t += 1
        compartment = self.compartment
        synapses = self.synapses

        compartment.advance_input()
        synapses.inject_current(self.t, self.inputs, self.axons, compartment.spiked)
        synapses.update_input(compartment.input)
        synapses.update_traces(self.t, self.rng)
        synapses.update_weights(self.t, self.rng)
        compartment.update(self.rng)
        self.probes.update(self.t, compartment)

    def get_probe_output(self, probe):
        """Return the data stored for ``probe`` (``self.probes[probe]``)."""
        return self.probes[probe]
class BlockInfo:
    """Provide information about all the LoihiBlocks in the model.

    Attributes
    ----------
    dtype : dtype
        Datatype of the blocks. Either ``np.float32`` if the blocks are not
        discretized or ``np.int32`` if they are. All blocks are the same.
    blocks : list of LoihiBlock
        List of all the blocks in the model.
    n_compartments : int
        Total number of compartments across all blocks.
    slices : dict of {LoihiBlock: slice}
        Maps each block to a slice for that block's compartments with
        respect to all compartments. Used to slice into any array storing
        data across all compartments.
    """

    def __init__(self, blocks):
        self.blocks = list(blocks)
        self.slices = OrderedDict()

        assert self.dtype in (np.float32, np.int32)

        # Lay the blocks out back to back in one flat compartment index space
        offset = 0
        for block in self.blocks:
            stop = offset + block.n_neurons
            self.slices[block] = slice(offset, stop)
            # every block must share the dtype of the first block
            assert block.compartment.vth.dtype == self.dtype
            assert block.compartment.bias.dtype == self.dtype
            offset = stop

        self.n_compartments = offset

    @property
    def dtype(self):
        """Dtype of the first block's ``compartment.vth`` array."""
        first_block = self.blocks[0]
        return first_block.compartment.vth.dtype
class IterableState:
    """Base class for aspects of the emulator state.

    This class takes the name of a LoihiBlock attribute as the
    ``block_key`` and maps these objects to their parent blocks and slices.

    Attributes
    ----------
    dtype : dtype
        Datatype of the state elements (given by the BlockInfo datatype).
    block_map : dict of {item: block}
        Maps an item (determined by ``block_key``) to the block
        it belongs to.
    n_compartments : int
        The total number of neuron compartments (given by BlockInfo).
    slices : dict of {item: slice}
        Maps an item to the ``block_info.slice`` for the block
        it belongs to.
    strict : bool (Default: True)
        Whether "undesired" chip effects (ex. overflow) raise errors (``True``)
        or whether they only raise warnings (``False``).
    """

    def __init__(self, block_info, block_key, strict=True):
        self.n_compartments = block_info.n_compartments
        self.dtype = block_info.dtype
        self.strict = strict

        # Build both item-keyed mappings in a single pass, in block order
        self.block_map = OrderedDict()
        self.slices = OrderedDict()
        for block, item in self._blocks_items(block_info.blocks, block_key):
            self.block_map[item] = block
            self.slices[item] = block_info.slices[block]

    @staticmethod
    def _blocks_items(blocks, block_key):
        """Yield ``(block, item)`` pairs for the ``block_key`` attribute."""
        # "compartment" holds one item per block; any other key names an
        # iterable attribute holding several items per block
        one_per_block = block_key == "compartment"
        for block in blocks:
            value = getattr(block, block_key)
            if one_per_block:
                yield block, value
            else:
                for item in value:
                    yield block, item

    def error(self, msg):
        """Raise ``SimulationError`` if strict, otherwise warn with ``msg``."""
        if not self.strict:
            warnings.warn(msg)
            return
        raise SimulationError(msg)

    def items(self):
        """Return the ``(item, slice)`` pairs of ``self.slices``."""
        return self.slices.items()
class CompartmentState(IterableState):
    """State representing the Compartments of all blocks."""

    MAX_DELAY = 1  # delay not yet implemented

    def __init__(self, block_info, strict=True):
        super().__init__(block_info, "compartment", strict=strict)
        n = self.n_compartments
        dtype = self.dtype

        # Dynamic state, one entry per compartment across all blocks
        self.input = np.zeros((self.MAX_DELAY, n), dtype=dtype)
        self.current = np.zeros(n, dtype=dtype)
        self.voltage = np.zeros(n, dtype=dtype)
        self.spiked = np.zeros(n, dtype=bool)
        self.spike_count = np.zeros(n, dtype=np.int32)
        self.ref_count = np.zeros(n, dtype=np.int32)

        # Scales default to one and are only overridden where a compartment
        # asks for scaling
        self.scale_u = np.ones(n, dtype=dtype)
        self.scale_v = np.ones(n, dtype=dtype)

        # Parameters that every compartment has to supply start out as NaN
        # and are checked after filling.
        # NOTE(review): NaN is not representable in an integer array, so for
        # int32 blocks these checks cannot flag unfilled entries -- confirm.
        required = ("decay_u", "decay_v", "vth", "vmin", "vmax", "bias", "ref")
        for attr in required:
            setattr(self, attr, np.full(n, np.nan, dtype=dtype))

        # Copy the parameters of each block's compartment into its slice
        for compartment, sl in self.items():
            self.decay_u[sl] = compartment.decay_u
            self.decay_v[sl] = compartment.decay_v
            if compartment.scale_u:
                self.scale_u[sl] = compartment.decay_u
            if compartment.scale_v:
                self.scale_v[sl] = compartment.decay_v
            self.vth[sl] = compartment.vth
            self.vmin[sl] = compartment.vmin
            self.vmax[sl] = compartment.vmax
            self.bias[sl] = compartment.bias
            self.ref[sl] = compartment.refract_delay

        for attr in required:
            assert not np.any(np.isnan(getattr(self, attr)))

        # Select the decay and overflow behaviour matching the block dtype
        if dtype == np.int32:
            assert (self.scale_u == 1).all()
            assert (self.scale_v == 1).all()

            def decay_current(x, u):
                return decay_int(x, self.decay_u, offset=1) + u

            def decay_voltage(x, u):
                return decay_int(x, self.decay_v) + u

            def overflow(x, bits, name=None):
                # wraps ``x`` in place and reports whether anything overflowed
                _, overflowed = overflow_signed(x, bits=bits, out=x)
                if np.any(overflowed):
                    suffix = " in %s" % name if name else ""
                    self.error("Overflow" + suffix)

        elif dtype == np.float32:

            def decay_current(x, u):
                return (1 - self.decay_u) * x + self.scale_u * u

            def decay_voltage(x, u):
                return (1 - self.decay_v) * x + self.scale_v * u

            def overflow(x, bits, name=None):
                pass  # do not do overflow in floating point

        else:
            raise ValidationError(
                "dtype %r not supported" % self.dtype, attr="dtype", obj=block_info
            )

        self._decay_current = decay_current
        self._decay_voltage = decay_voltage
        self._overflow = overflow

        self.noise = NoiseState(block_info)

    def advance_input(self):
        """Shift the input buffer forward one step and clear the last row."""
        self.input[:-1] = self.input[1:]
        self.input[-1] = 0

    def update(self, rng):
        """Update current, voltage, spikes and refractory counters one step.

        Parameters
        ----------
        rng : np.random.RandomState
            Random state handed to the noise sampler.
        """
        noise = self.noise.sample(rng)
        on_u = self.noise.target_u
        off_u = ~on_u

        # Compartments not flagged by ``target_u`` get noise on their input
        q0 = self.input[0, :]
        q0[off_u] += noise[off_u]
        self._overflow(q0, bits=Q_BITS, name="q0")

        self.current[:] = self._decay_current(self.current, q0)
        self._overflow(self.current, bits=U_BITS, name="current")

        # Flagged compartments get noise on current + bias instead
        u2 = self.current + self.bias
        u2[on_u] += noise[on_u]
        self._overflow(u2, bits=U_BITS, name="u2")

        self.voltage[:] = self._decay_voltage(self.voltage, u2)
        # We have not been able to create V overflow on the chip, so we do
        # not include it here. See github.com/nengo/nengo-loihi/issues/130
        np.clip(self.voltage, self.vmin, self.vmax, out=self.voltage)

        refractory = self.ref_count > 0
        self.voltage[refractory] = 0
        # TODO^: don't zero voltage in case neuron is saving overshoot

        self.spiked[:] = self.voltage > self.vth
        fired = self.spiked
        self.voltage[fired] = 0
        self.ref_count[fired] = self.ref[fired]

        # decrement ref_count, never going below zero
        np.clip(self.ref_count - 1, 0, None, out=self.ref_count)

        self.spike_count[fired] += 1
class NoiseState(IterableState):
    """State representing the noise parameters for all compartments."""

    def __init__(self, block_info):
        super().__init__(block_info, "compartment")
        n_total = self.n_compartments

        # NOTE(review): NaN is not representable in bool or integer arrays,
        # so the NaN checks further down can only flag unfilled entries for
        # floating-point arrays -- confirm this is acceptable.
        self.enabled = np.full(n_total, np.nan, dtype=bool)
        self.exp = np.full(n_total, np.nan, dtype=self.dtype)
        self.mant_offset = np.full(n_total, np.nan, dtype=self.dtype)
        self.target_u = np.full(n_total, np.nan, dtype=bool)

        # Copy the noise parameters of each block's compartment into its slice
        for compartment, sl in self.items():
            self.enabled[sl] = compartment.enable_noise
            self.exp[sl] = compartment.noise_exp
            self.mant_offset[sl] = compartment.noise_offset
            self.target_u[sl] = compartment.noise_at_membrane

        if self.dtype == np.int32:
            # TODO: if we could do this mult with shifts, it'd be faster, but
            # numpy has no function taking a vector of positive/negative shifts
            self.mult = np.where(self.enabled, 2.0 ** (self.exp - 7), 0)
            self.mant_offset *= 64

            def uniform(rng, n=n_total):
                return rng.randint(-127, 128, size=n, dtype=np.int32)

        elif self.dtype == np.float32:
            self.mult = np.where(self.enabled, 10.0 ** self.exp, 0)

            def uniform(rng, n=n_total):
                return rng.uniform(-1, 1, size=n).astype(np.float32)

        else:
            raise ValidationError(
                "dtype %r not supported" % self.dtype, attr="dtype", obj=block_info
            )

        for attr in ("enabled", "exp", "mant_offset", "target_u", "mult"):
            assert not np.any(np.isnan(getattr(self, attr)))

        self._uniform = uniform

    def sample(self, rng):
        """Draw one noise value per compartment, cast to ``self.dtype``.

        Compartments with noise disabled have a zero multiplier and
        therefore always get zero.
        """
        draw = self._uniform(rng)
        scaled = (draw + self.mant_offset) * self.mult
        return scaled.astype(self.dtype)
class SynapseState(IterableState):
    """State representing all synapses.

    Attributes
    ----------
    pes_error_scale : float
        Scaling for the errors of PES learning rules.
    pes_errors : {Synapse: ndarray(n_neurons / 2)}
        Maps synapse to PES learning rule errors for those synapses.
    spikes_in : {Synapse: list}
        Maps synapse to a queue of input spikes targeting those synapses.
    traces : {Synapse: ndarray(Synapse.n_axons)}
        Maps synapse to trace values for each of their axons.
    trace_spikes : {Synapse: set}
        Maps synapse to a queue of input spikes waiting to be added to those
        synapse traces.
    """

    def __init__(self, block_info, pes_error_scale=1.0, strict=True):  # noqa: C901
        super().__init__(block_info, "synapses", strict=strict)

        self.pes_error_scale = pes_error_scale

        self.spikes_in = OrderedDict()
        self.traces = OrderedDict()
        self.trace_spikes = OrderedDict()
        self.pes_errors = OrderedDict()
        for synapse in self.slices:
            n_axons = synapse.n_axons
            self.spikes_in[synapse] = []

            if not synapse.learning:
                continue

            self.traces[synapse] = np.zeros(n_axons, dtype=self.dtype)
            self.trace_spikes[synapse] = set()
            # Currently, PES learning only happens on Nodes, where we
            # have pairs of on/off neurons. Therefore, the number of error
            # dimensions is half the number of neurons.
            n_errors = self.block_map[synapse].n_neurons // 2
            self.pes_errors[synapse] = np.zeros(n_errors, dtype=self.dtype)

        if self.dtype == np.int32:

            def stochastic_round(
                x, dtype=self.dtype, rng=None, clip=None, name="values"
            ):
                # Round each magnitude up with probability equal to its
                # fractional part, then restore the sign
                sign = np.sign(x).astype(dtype)
                frac, whole = np.modf(np.abs(x))
                threshold = rng.rand(*x.shape)
                magnitude = whole.astype(dtype) + (frac > threshold)

                if clip is not None:
                    too_big = magnitude > clip
                    if np.any(too_big):
                        warnings.warn("Clipping %s" % name)
                    magnitude[too_big] = clip

                return sign * magnitude

            def trace_round(x, rng=None):
                return stochastic_round(x, rng=rng, clip=127, name="synapse trace")

            def weight_update(synapse, delta_ws, rng=None):
                cfg = synapse.synapse_cfg
                wgt_exp = cfg.real_weight_exp
                shift_bits = cfg.shift_bits
                overflow = learn_overflow_bits(n_factors=2)

                for w, delta_w in zip(synapse.weights, delta_ws):
                    product = shift(
                        delta_w * synapse._lr_int,
                        LEARN_FRAC + synapse._lr_exp - overflow,
                    )
                    learn_w = shift(w, LEARN_FRAC - wgt_exp) + product
                    learn_w[:] = stochastic_round(
                        learn_w * 2 ** (-LEARN_FRAC - shift_bits),
                        clip=2 ** (8 - shift_bits) - 1,
                        rng=rng,
                        name="learning weights",
                    )
                    w[:] = np.left_shift(learn_w, wgt_exp + shift_bits)

        elif self.dtype == np.float32:

            def trace_round(x, rng=None):
                return x  # no rounding

            def weight_update(synapse, delta_ws, rng=None):
                for w, delta_w in zip(synapse.weights, delta_ws):
                    w += synapse.learning_rate * delta_w

        else:
            raise ValidationError(
                "dtype %r not supported" % self.dtype, attr="dtype", obj=block_info
            )

        self._trace_round = trace_round
        self._weight_update = weight_update

    def _queue_spikes(self, axon, compartment_idxs):
        """Map compartment indices through ``axon`` and queue the results.

        Entries that ``axon.map_spikes`` returns as ``None`` are dropped.
        """
        mapped = axon.map_spikes(compartment_idxs)
        self.spikes_in[axon.target].extend(s for s in mapped if s is not None)

    def inject_current(self, t, spike_inputs, all_axons, spiked):
        """Collect the spikes arriving at every synapse on step ``t``.

        Parameters
        ----------
        t : int
            Current step index.
        spike_inputs : iterable
            Spike inputs; each provides ``spike_idxs(t)`` and ``axons``.
        all_axons : AxonState
            Its ``items()`` give ``(axon, slice)`` pairs.
        spiked : ndarray of bool
            Which compartments spiked, indexed over all compartments.
        """
        # --- clear spikes going in to each synapse
        for queue in self.spikes_in.values():
            queue.clear()

        # --- inputs pass spikes to synapses
        if t >= 2:  # input spikes take one time-step to arrive
            for spike_input in spike_inputs:
                idxs = spike_input.spike_idxs(t - 1)
                for axon in spike_input.axons:
                    self._queue_spikes(axon, idxs)

        # --- axons pass spikes to synapses
        for axon, a_idx in all_axons.items():
            idxs = spiked[a_idx].nonzero()[0]
            self._queue_spikes(axon, idxs)

    def update_input(self, input):
        """Add the weights of all queued spikes onto row 0 of ``input``.

        Parameters
        ----------
        input : ndarray
            Compartment input buffer, indexed here as ``input[:, slice]``;
            modified in place.
        """
        for synapse, s_slice in self.items():
            block_input = input[:, s_slice]

            for spike in self.spikes_in[synapse]:
                base = synapse.axon_compartment_base(spike.axon_id)
                if base is not None:
                    weights, indices = synapse.axon_weights_indices(
                        spike.axon_id, atom=spike.atom
                    )
                    block_input[0, base + indices] += weights

    def update_pes_errors(self, errors):
        """Reset all PES errors, then accumulate the scaled ``errors``.

        Parameters
        ----------
        errors : iterable of (synapse, _, e)
            ``e`` must have the same shape as the stored error array for
            ``synapse``; the middle element is ignored.
        """
        # TODO: these are sent every timestep, but learning only happens every
        # `tepoch * 2**learn_k` timesteps (see Synapse). Need to average.
        for stored in self.pes_errors.values():
            stored[:] = 0

        for synapse, _, e in errors:
            stored = self.pes_errors[synapse]
            assert stored.shape == e.shape
            stored += scale_pes_errors(e, scale=self.pes_error_scale)

    def update_weights(self, t, rng):
        """Apply the learning update on steps that end a learning epoch.

        A synapse is updated only when ``t`` is a multiple of its
        ``learn_epoch``.
        """
        for synapse, pes_error in self.pes_errors.items():
            if t % synapse.learn_epoch != 0:
                continue

            trace = self.traces[synapse]
            # negated errors first, then the errors themselves
            signed_error = np.hstack([-pes_error, pes_error])
            delta_w = np.outer(trace, signed_error)
            self._weight_update(synapse, delta_w, rng=rng)

    def update_traces(self, t, rng):
        """Record incoming spikes and decay/update traces on epoch steps.

        Reports an error (or a warning if not strict) when an axon spikes
        again before its earlier spike was folded into the trace.
        """
        for synapse, trace in self.traces.items():
            pending = self.trace_spikes.get(synapse, None)
            if pending is not None:
                for spike in self.spikes_in[synapse]:
                    if spike.axon_id in pending:
                        self.error("Synaptic trace spikes lost")
                    pending.add(spike.axon_id)

            if trace is None or t % synapse.train_epoch != 0:
                continue

            tau = synapse.tracing_tau
            decay = np.exp(-synapse.train_epoch / tau)
            decayed = decay * trace
            decayed[list(pending)] += synapse.tracing_mag
            trace[:] = self._trace_round(decayed, rng=rng)
            pending.clear()
class AxonState(IterableState):
    """State representing all (output) Axons."""

    def __init__(self, block_info):
        # Items are the entries of each block's ``axons`` attribute
        super().__init__(block_info, "axons")
class ProbeState:
    """State representing all probes.

    Parameters
    ----------
    block_info : BlockInfo
        Provides the blocks (with their probes) and each block's slice.
    inputs : list
        Accepted for interface compatibility; not used by this class.
    dt : float
        Length of one simulation step.

    Attributes
    ----------
    dt : float
        Length of one simulation step; passed to the probe filters and used
        to timestamp the data handed to receivers.
    probes : {nengo_loihi.Probe: slice}
        Maps Probes to the BlockInfo slice for the block they are probing.
    filters : {nengo_loihi.Probe: function}
        Maps Probes that have a synapse to their filtering step function.
    filter_pos : {nengo_loihi.Probe: int}
        Maps Probes to the index of the next step their filter will process.
    outputs : defaultdict of {nengo_loihi.Probe: list}
        Maps Probes to the list of raw outputs recorded so far, one entry
        per call to ``update``.
    """

    def __init__(self, block_info, inputs, dt):
        self.dt = dt

        self.probes = OrderedDict()
        for block in block_info.blocks:
            for probe in block.probes:
                self.probes[probe] = block_info.slices[block]

        self.filters = {}
        self.filter_pos = {}
        for probe, sl in self.probes.items():
            if probe.synapse is None:
                continue

            # Without a weight matrix the filter sees one value per
            # compartment in the block; with one, one value per output column
            if probe.weights is None or is_number(probe.weights):
                size = sl.stop - sl.start
            else:
                assert is_array(probe.weights) and probe.weights.ndim == 2
                size = probe.weights.shape[1]

            self.filters[probe] = make_process_step(
                probe.synapse,
                shape_in=(size,),
                shape_out=(size,),
                dt=self.dt,
                rng=None,
                dtype=np.float32,
            )
            self.filter_pos[probe] = 0

        self.outputs = defaultdict(list)

    def __getitem__(self, probe):
        """Return all recorded data for ``probe``, weighted and filtered.

        Weights are applied when the probe has them, and the result is run
        through the probe's filter when it has one.
        """
        assert isinstance(probe, Probe)
        out = np.asarray(self.outputs[probe], dtype=np.float32)
        out = out if probe.weights is None else np.dot(out, probe.weights)
        return self._filter(probe, out) if probe in self.filters else out

    def _filter(self, probe, data):
        """Run ``data`` through the filter of ``probe``, one row per step.

        Parameters
        ----------
        probe : nengo_loihi.Probe
            Probe whose filter step function is applied.
        data : ndarray
            Data to filter; iterated along its first axis.

        Returns
        -------
        filt_data : ndarray
            Filtered data, with the same shape and dtype as ``data``.
        """
        dt = self.dt
        start = self.filter_pos[probe]
        step = self.filters[probe]
        filt_data = np.zeros_like(data)
        for k, x in enumerate(data):
            filt_data[k] = step((start + k) * dt, x)
        # Advance by the number of rows consumed. Using the last loop index
        # here would fail on empty data and lag one step behind otherwise.
        self.filter_pos[probe] = start + len(data)
        return filt_data

    def send(self, probe, already_sent, receiver):
        """Send probed data to the receiver node.

        Parameters
        ----------
        probe : nengo_loihi.Probe
            Probe whose recorded outputs are sent.
        already_sent : int
            Number of leading outputs to skip because they were sent before.
        receiver : object
            Receives each step through ``receiver.receive(t, x)``, with
            ``t = dt * (already_sent + j + 2)`` for the ``j``-th new step.

        Returns
        -------
        steps : int
            The number of steps sent to the receiver.
        """
        x = self.outputs[probe][already_sent:]
        if len(x) > 0:
            if probe.weights is not None:
                x = np.dot(x, probe.weights)
            for j, xx in enumerate(x):
                receiver.receive(self.dt * (already_sent + j + 2), xx)
        return len(x)

    def update(self, t, compartment):
        """Record the current value of every probe's target.

        Parameters
        ----------
        t : int
            Current step index (not used here).
        compartment : CompartmentState
            State object holding the arrays named by each ``probe.key``.
        """
        for probe, sl in self.probes.items():
            p_slice = probe.slice
            assert hasattr(compartment, probe.key)
            # copy, since the state arrays are updated in place every step
            output = getattr(compartment, probe.key)[sl][p_slice].copy()
            self.outputs[probe].append(output)