import logging
import warnings
from nengo.exceptions import BuildError
import numpy as np
from nengo_loihi.block import SynapseConfig
from nengo_loihi.compat import is_iterable
from nengo_loihi.nxsdk_obfuscation import d
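# Note (summary of the conversions below): the chip stores voltage thresholds
# as a mantissa with a fixed exponent, and biases as a mantissa with a
# variable exponent; see `vth_to_manexp` and `bias_to_manexp`.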
VTH_MAN_MAX = d(b"MTMxMDcx", int)
VTH_EXP = d(b"Ng==", int)
VTH_MAX = VTH_MAN_MAX * 2 ** VTH_EXP
BIAS_MAN_MAX = d(b"NDA5NQ==", int)
BIAS_EXP_MAX = d(b"Nw==", int)
BIAS_MAX = BIAS_MAN_MAX * 2 ** BIAS_EXP_MAX
# number of bits for synapse accumulator
Q_BITS = d(b"MjE=", int)
# number of bits for compartment input (u)
U_BITS = d(b"MjM=", int)
# number of bits in learning accumulator (not incl. sign)
LEARN_BITS = d(b"MTU=", int)
# extra least-significant bits added to weights for learning
LEARN_FRAC = d(b"Nw==", int)
logger = logging.getLogger(__name__)
def array_to_int(array, value):
    """Fill ``array`` with the rounded integer ``value``, retyping it to int32 in-place."""
    assert array.dtype == np.float32
    new = np.round(value).astype(np.int32)
    array.dtype = np.int32  # reinterpret the float32 buffer as int32, in-place
    array[:] = new
def learn_overflow_bits(n_factors):
"""Compute number of bits with which learning will overflow.
Parameters
----------
n_factors : int
The number of learning factors (pre/post terms in the learning rule).
"""
factor_bits = 7 # number of bits per factor
mantissa_bits = 3 # number of bits for learning rate mantissa
return factor_bits * n_factors + mantissa_bits - LEARN_BITS
def overflow_signed(x, bits=7, out=None):
"""Compute overflow on an array of signed integers.
For example, the Loihi chip uses 23 bits plus sign to represent U.
We can store them as 32-bit integers, and use this function to compute
how they would overflow if we only had 23 bits plus sign.
Parameters
----------
x : array
Integer values for which to compute values after overflow.
bits : int
Number of bits, not including sign, to compute overflow for.
out : array, optional (Default: None)
Output array to put computed overflow values in.
Returns
-------
y : array
Values of x overflowed as would happen with limited bit representation.
overflowed : array
Boolean array indicating which values of ``x`` actually overflowed.
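
    Examples
    --------
    A small illustration with arbitrary values, wrapping to 7 bits plus sign
    (i.e. like an 8-bit signed integer)::

        x = np.array([-129, -128, 127, 128])
        y, overflowed = overflow_signed(x, bits=7)
        # y -> array([ 127, -128,  127, -128])
        # overflowed -> array([ True, False, False,  True])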
"""
if out is None:
out = np.array(x)
else:
assert isinstance(out, np.ndarray)
out[:] = x
assert np.issubdtype(out.dtype, np.integer)
x1 = np.array(1, dtype=out.dtype)
smask = np.left_shift(x1, bits) # mask for the sign bit (2**bits)
xmask = smask - 1 # mask for all bits <= `bits`
# find whether we've overflowed
overflowed = (out < -smask) | (out >= smask)
zmask = out & smask # if `out` has negative sign bit, == 2**bits
out &= xmask # mask out all bits > `bits`
out -= zmask # subtract 2**bits if negative sign bit
return out, overflowed
def vth_to_manexp(vth):
    """Convert voltage thresholds to (mantissa, exponent) pairs for the chip."""
    exp = VTH_EXP * np.ones(vth.shape, dtype=np.int32)
man = np.round(vth / 2 ** exp).astype(np.int32)
assert (man > 0).all()
assert (man <= VTH_MAN_MAX).all()
return man, exp
def bias_to_manexp(bias):
    """Convert biases to (mantissa, exponent) pairs for the chip."""
    r = np.maximum(np.abs(bias) / BIAS_MAN_MAX, 1)
exp = np.ceil(np.log2(r)).astype(np.int32)
man = np.round(bias / 2 ** exp).astype(np.int32)
assert (exp >= 0).all()
assert (exp <= BIAS_EXP_MAX).all()
assert (np.abs(man) <= BIAS_MAN_MAX).all()
return man, exp
def tracing_mag_int_frac(mag):
"""Split trace magnitude into integer and fractional components for chip"""
mag_int = int(mag)
mag_frac = int(d(b"MTI4", int) * (mag - mag_int))
return mag_int, mag_frac
def decay_int(x, decay, bits=None, offset=0, out=None):
"""Decay integer values using a decay constant.
The decayed value is given by::
sign(x) * floor(abs(x) * (2**bits - offset - decay) / 2**bits)
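
    Examples
    --------
    A small illustration with made-up values; with ``bits=12``, a ``decay`` of
    2048 gives a decay ratio of ``(2**12 - 2048) / 2**12 = 0.5``::

        decay_int(np.array([1000, -1000]), decay=2048, bits=12)
        # -> array([ 500, -500])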
"""
if out is None:
out = np.zeros_like(x)
if bits is None:
bits = d(b"MTI=", int)
r = (2 ** bits - offset - np.asarray(decay)).astype(np.int64)
np.right_shift(np.abs(x) * r, bits, out=out)
return np.sign(x) * out
def decay_magnitude(decay, x0=2 ** 21, bits=12, offset=0):
"""Estimate the sum of the series of rounded integer decays of ``x0``.
This can be used to estimate the total input current or voltage (summed
over time) caused by an input of magnitude ``x0``. In real values, this is
easy to calculate as the integral of an exponential. In integer values,
we need to account for the rounding down that happens each time the decay
is computed.
    Specifically, we estimate the sum of the series::

        x_i = floor(r * x_{i-1})

    where ``r = (2**bits - offset - decay) / 2**bits`` is the decay ratio.

    To simulate the effects of rounding in decay, we subtract an expected loss
    due to rounding (``q``) each iteration. Our estimated series is therefore::

        y_i = r * y_{i-1} - q
            = r^i * x_0 - sum_k^{i-1} q * r^k
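
    Examples
    --------
    As a rough, illustrative check: with ``bits=12`` and ``decay=2048`` the
    decay ratio is 0.5, so the series sums to a little under
    ``1 / (1 - 0.5) = 2`` times ``x0``, and this function returns a value
    just below 2::

        decay_magnitude(2048, x0=2 ** 21, bits=12)  # approximately 2.0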
"""
# q: Expected loss per time step (found by empirical simulations). If the
# value being rounded down were uniformly distributed between 0 and 1, this
# should be 0.5 exactly, but empirically this does not appear to be the
# case and this value is better (see `test_decay_magnitude`).
q = 0.494
r = (2 ** bits - offset - np.asarray(decay)) / 2 ** bits # decay ratio
n = -(np.log1p(x0 * (1 - r) / q)) / np.log(r) # solve y_n = 0 for n
# principal_sum = (1./x0) sum_i^n x0 * r^i
# loss_sum = (1./x0) sum_i^n sum_k^{i-1} q * r^k
principal_sum = (1 - r ** (n + 1)) / (1 - r)
loss_sum = q / ((1 - r) * x0) * (n + 1 - (1 - r ** (n + 1)) / (1 - r))
return principal_sum - loss_sum
def scale_pes_errors(error, scale=1.0):
"""Scale PES errors based on a scaling factor, round and clip."""
error = scale * error
error = np.round(error).astype(np.int32)
max_err = d(b"MTI3", int)
q = error > max_err
if np.any(q):
warnings.warn(
"Received PES error greater than chip max (%0.2e). "
"Consider changing `Model.pes_error_scale`." % (max_err / scale,)
)
logger.debug(
"PES error %0.2e > %0.2e (chip max)", np.max(error) / scale, max_err / scale
)
error[q] = max_err
q = error < -max_err
if np.any(q):
warnings.warn(
"Received PES error less than chip min (%0.2e). "
"Consider changing `Model.pes_error_scale`." % (-max_err / scale,)
)
logger.debug(
"PES error %0.2e < %0.2e (chip min)",
np.min(error) / scale,
-max_err / scale,
)
error[q] = -max_err
return error
def shift(x, s, **kwargs):
    """Shift ``x`` left by ``s`` bits if ``s >= 0``, otherwise right by ``-s`` bits."""
    if s < 0:
return np.right_shift(x, -s, **kwargs)
else:
return np.left_shift(x, s, **kwargs)
def discretize_model(model):
"""Discretize a `.Model` in-place.
Turns a floating-point `.Model` into a discrete (integer) model
appropriate for Loihi.
Parameters
----------
model : `.Model`
The model to discretize.
"""
for block in model.blocks:
discretize_block(block)
def discretize_block(block):
"""Discretize a `.LoihiBlock` in-place.
Turns a floating-point `.LoihiBlock` into a discrete (integer)
block appropriate for Loihi.
Parameters
----------
block : `.LoihiBlock`
The block to discretize.
"""
w_maxs = [s.max_abs_weight() for s in block.synapses]
w_max = max(w_maxs) if len(w_maxs) > 0 else 0
p = discretize_compartment(block.compartment, w_max)
for synapse in block.synapses:
discretize_synapse(synapse, w_max, p["w_scale"], p["w_exp"])
for probe in block.probes:
discretize_probe(probe, p["v_scale"][0])
def discretize_compartment(comp, w_max):
"""Discretize a `.Compartment` in-place.
Turns a floating-point `.Compartment` into a discrete (integer)
    compartment appropriate for Loihi.
Parameters
----------
comp : `.Compartment`
The compartment to discretize.
w_max : float
The largest connection weight in the `.LoihiBlock` containing
``comp``. Used to set several scaling factors.
"""
# --- discretize decay_u and decay_v
# subtract 1 from decay_u here because it gets added back by the chip
decay_u = comp.decay_u * d(b"NDA5NQ==", int) - 1
array_to_int(comp.decay_u, np.clip(decay_u, 0, d(b"NDA5NQ==", int)))
array_to_int(comp.decay_v, comp.decay_v * d(b"NDA5NQ==", int))
# Compute factors for current and voltage decay. These factors
# counteract the fact that for longer decays, the current (or voltage)
# created by a single spike has a larger integral.
u_infactor = (
1.0 / decay_magnitude(comp.decay_u, x0=d(b"MjA5NzE1Mg==", int), offset=1)
if comp.scale_u
else np.ones(comp.decay_u.shape)
)
v_infactor = (
1.0 / decay_magnitude(comp.decay_v, x0=d(b"MjA5NzE1Mg==", int))
if comp.scale_v
else np.ones(comp.decay_v.shape)
)
comp.scale_u = False
comp.scale_v = False
# --- discretize weights and vth
# To avoid overflow, we can either lower vth_max or lower w_exp_max.
# Lowering vth_max is more robust, but has the downside that it may
# force smaller w_exp on connections than necessary, potentially
# leading to lost weight bits (see discretize_weights).
# Lowering w_exp_max can let us keep vth_max higher, but overflow
    # is still possible on connections with many small inputs (uncommon).
vth_max = VTH_MAX
w_exp_max = 0
b_max = np.abs(comp.bias).max()
w_exp = 0
if w_max > 1e-8:
w_scale = d(b"MjU1", float) / w_max
s_scale = 1.0 / (u_infactor * v_infactor)
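        # Try progressively smaller weight exponents (which shrink v_scale and
        # b_scale) until the discretized vth and bias both fit on the chip.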
for w_exp in range(w_exp_max, d(b"LTg=", int), d(b"LTE=", int)):
v_scale = s_scale * w_scale * SynapseConfig.get_scale(w_exp)
b_scale = v_scale * v_infactor
vth = np.round(comp.vth * v_scale)
bias = np.round(comp.bias * b_scale)
if (vth <= vth_max).all() and (np.abs(bias) <= BIAS_MAX).all():
break
else:
raise BuildError("Could not find appropriate weight exponent")
elif b_max > 1e-8:
b_scale = BIAS_MAX / b_max
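        # No significant weights: start from the largest bias scaling and
        # halve it until the discretized vth fits on the chip.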
while b_scale * b_max > 1:
v_scale = b_scale / v_infactor
w_scale = b_scale * u_infactor / SynapseConfig.get_scale(w_exp)
vth = np.round(comp.vth * v_scale)
bias = np.round(comp.bias * b_scale)
if np.all(vth <= vth_max):
break
b_scale /= 2.0
else:
raise BuildError("Could not find appropriate bias scaling")
else:
# reduce vth_max in this case to avoid overflow since we're setting
# all vth to vth_max (esp. in learning with zeroed initial weights)
vth_max = min(vth_max, 2 ** Q_BITS - 1)
v_scale = np.array([vth_max / (comp.vth.max() + 1)])
vth = np.round(comp.vth * v_scale)
b_scale = v_scale * v_infactor
bias = np.round(comp.bias * b_scale)
w_scale = v_scale * v_infactor * u_infactor / SynapseConfig.get_scale(w_exp)
vth_man, vth_exp = vth_to_manexp(vth)
array_to_int(comp.vth, vth_man * 2 ** vth_exp)
bias_man, bias_exp = bias_to_manexp(bias)
array_to_int(comp.bias, bias_man * 2 ** bias_exp)
# --- noise
assert (v_scale[0] == v_scale).all()
enable_noise = np.any(comp.enable_noise)
noise_exp = np.round(np.log2(10.0 ** comp.noise_exp * v_scale[0]))
if enable_noise and noise_exp < d(b"MQ==", int):
warnings.warn("Noise amplitude falls below lower limit")
enable_noise = False
if enable_noise and noise_exp > d(b"MjM=", int):
warnings.warn("Noise amplitude exceeds upper limit (%d > 23)" % (noise_exp,))
comp.noise_exp = int(np.clip(noise_exp, d(b"MQ==", int), d(b"MjM=", int)))
comp.noise_offset = int(np.round(2 * comp.noise_offset))
# --- vmin and vmax
assert (v_scale[0] == v_scale).all()
vmin = v_scale[0] * comp.vmin
vmax = v_scale[0] * comp.vmax
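    # Round vmin and vmax to values the chip can represent: vmin is stored as
    # -(2**vmine) + 1 (vmine has 5 bits) and vmax as 2**(9 + 2*vmaxe) - 1
    # (vmaxe has 3 bits).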
vmine = np.clip(np.round(np.log2(-vmin + 1)), 0, 2 ** 5 - 1)
comp.vmin = -(2 ** vmine) + 1
vmaxe = np.clip(np.round((np.log2(vmax + 1) - 9) * 0.5), 0, 2 ** 3 - 1)
comp.vmax = 2 ** (9 + 2 * vmaxe) - 1
return dict(w_max=w_max, w_scale=w_scale, w_exp=w_exp, v_scale=v_scale)
def discretize_synapse(synapse, w_max, w_scale, w_exp):
"""Discretize a `.Synapse` in-place.
Turns a floating-point `.Synapse` into a discrete (integer)
    synapse appropriate for Loihi.
Parameters
----------
synapse : `.Synapse`
The synapse to discretize.
w_max : float
The largest connection weight in the `.LoihiBlock` containing
``synapse``. Used to scale weights appropriately.
w_scale : float
Connection weight scaling factor. Usually computed by
`.discretize_compartment`.
w_exp : float
Exponent on the connection weight scaling factor. Usually computed by
`.discretize_compartment`.
"""
w_max_i = synapse.max_abs_weight()
if synapse.learning:
w_exp2 = synapse.learning_wgt_exp
dw_exp = w_exp - w_exp2
elif w_max_i > 1e-16:
dw_exp = int(np.floor(np.log2(w_max / w_max_i)))
assert dw_exp >= 0
w_exp2 = max(w_exp - dw_exp, d(b"LTY=", int))
else:
w_exp2 = d(b"LTY=", int)
dw_exp = w_exp - w_exp2
synapse.format(weight_exp=w_exp2)
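    # Scale and discretize the weights. dw_exp compensates for the difference
    # between the block-level weight exponent (w_exp) and this synapse's
    # exponent (w_exp2), so the effective weight scale stays consistent.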
for w, idxs in zip(synapse.weights, synapse.indices):
ws = w_scale[idxs] if is_iterable(w_scale) else w_scale
array_to_int(w, discretize_weights(synapse.synapse_cfg, w * ws * 2 ** dw_exp))
# discretize learning
if synapse.learning:
synapse.tracing_tau = int(np.round(synapse.tracing_tau))
if is_iterable(w_scale):
assert np.all(w_scale == w_scale[0])
w_scale_i = w_scale[0] if is_iterable(w_scale) else w_scale
        # Incorporate the weight scale and the difference in weight exponents
        # into the learning rate, since these affect the speed at which we learn.
ws = w_scale_i * 2 ** dw_exp
synapse.learning_rate *= ws
# Loihi down-scales learning factors based on the number of
# overflow bits. Increasing learning rate maintains true rate.
synapse.learning_rate *= 2 ** learn_overflow_bits(2)
        # TODO: Currently, the Loihi learning rate is fixed at 2**-7.
# We should explore adjusting it for better performance.
lscale = 2 ** -7 / synapse.learning_rate
synapse.learning_rate *= lscale
synapse.tracing_mag /= lscale
# discretize learning rate into mantissa and exponent
lr_exp = int(np.floor(np.log2(synapse.learning_rate)))
lr_int = int(np.round(synapse.learning_rate * 2 ** (-lr_exp)))
synapse.learning_rate = lr_int * 2 ** lr_exp
synapse._lr_int = lr_int
synapse._lr_exp = lr_exp
assert lr_exp >= d(b"LTc=", int)
# discretize tracing mag into integer and fractional components
mag_int, mag_frac = tracing_mag_int_frac(synapse.tracing_mag)
if mag_int > d(b"MTI3", int):
warnings.warn(
"Trace increment exceeds upper limit "
"(learning rate may be too large)"
)
mag_int = d(b"MTI3", int)
mag_frac = d(b"MTI3", int)
synapse.tracing_mag = mag_int + mag_frac / d(b"MTI4", float)
def discretize_weights(
synapse_cfg, w, dtype=np.int32, lossy_shift=True, check_result=True
):
"""Takes weights and returns their quantized values with weight_exp.
The actual weight to be put on the chip is this returned value
divided by the ``scale`` attribute.
Parameters
----------
w : float ndarray
Weights to be discretized, in the range -255 to 255.
dtype : np.dtype, optional (Default: np.int32)
Data type for discretized weights.
lossy_shift : bool, optional (Default: True)
Whether to mimic the two-part weight shift that currently happens
on the chip, which can lose information for small weight_exp.
check_results : bool, optional (Default: True)
Whether to check that the discretized weights fall in
the valid range for weights on the chip (-256 to 255).
"""
s = synapse_cfg.shift_bits
m = 2 ** (d(b"OA==", int) - s) - 1
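    # m is the maximum magnitude of the weight mantissa once the `s`
    # least-significant (shift) bits have been dropped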
w = np.round(w / 2.0 ** s).clip(-m, m).astype(dtype)
s2 = s + synapse_cfg.weight_exp
if lossy_shift:
if s2 < 0:
warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))
# Round before `s2` right shift. Just shifting would floor
# everything resulting in weights biased towards being smaller.
w = (np.round(w * 2.0 ** s2) / 2 ** s2).clip(-m, m).astype(dtype)
shift(w, s2, out=w)
np.left_shift(w, d(b"Ng==", int), out=w)
else:
shift(w, d(b"Ng==", int) + s2, out=w)
if check_result:
ws = w // synapse_cfg.scale
assert np.all(ws <= d(b"MjU1", int)) and np.all(ws >= d(b"LTI1Ng==", int))
return w
def discretize_probe(probe, v_scale):
    """Rescale voltage-probe weights to undo the block's voltage scaling."""
    if probe.key == "voltage" and probe.weights is not None:
        probe.weights /= v_scale