import nengo
import numpy as np
from nengo.config import Default
from nengo.exceptions import ValidationError
from nengo.params import IntParam, Parameter
from nengo.utils.stdlib import checked_call
from nengo_spa.ast.symbolic import PointerSymbol
from nengo_spa.network import Network
from nengo_spa.semantic_pointer import SemanticPointer
from nengo_spa.vocabulary import VocabularyOrDimParam
class SpArrayExtractor:
    """Callable that reduces Semantic Pointer-like values to their vector.

    The conversion is applied in stages, each feeding the next:

    1. A `PointerSymbol` is replaced by its ``expr`` attribute.
    2. A string is parsed with ``vocab.parse``.
    3. A `SemanticPointer` is replaced by its ``v`` attribute.

    Anything that matches none of these stages is returned unchanged.

    Parameters
    ----------
    vocab
        Object providing a ``parse`` method; it is used to resolve string
        expressions.
    """

    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, value):
        """Return *value* converted according to the stages described above."""
        # Stage 1: unwrap symbolic pointers.
        result = value.expr if isinstance(value, PointerSymbol) else value
        # Stage 2: resolve string expressions in the vocabulary.
        if isinstance(result, str):
            result = self.vocab.parse(result)
        # Stage 3: take the vector of a Semantic Pointer, pass others through.
        return result.v if isinstance(result, SemanticPointer) else result
def make_sp_func(fn, vocab):
    """Wrap *fn* so that it receives its input as a `SemanticPointer`.

    Parameters
    ----------
    fn : callable
        Function taking ``(t, sp)``.
    vocab
        Passed as the ``vocab`` keyword when constructing each
        `SemanticPointer`.

    Returns
    -------
    callable
        Function taking ``(t, v)`` that builds ``SemanticPointer(v,
        vocab=vocab)`` and forwards it, together with *t*, to *fn*.
    """

    def sp_func(t, v):
        pointer = SemanticPointer(v, vocab=vocab)
        return fn(t, pointer)

    return sp_func
def make_parse_func(fn, vocab):
    """Wrap *fn* so that its return value is run through `SpArrayExtractor`.

    Parameters
    ----------
    fn : callable
        Function whose output should be converted.
    vocab
        Vocabulary handed to the `SpArrayExtractor` that post-processes the
        output of *fn*.

    Returns
    -------
    callable
        Function accepting the same positional arguments as *fn* and
        returning the extractor's result for ``fn(*args)``.
    """
    # One extractor is created up front and shared by every call.
    to_array = SpArrayExtractor(vocab)

    def parse_func(*args):
        raw_output = fn(*args)
        return to_array(raw_output)

    return parse_func
class TranscodeFunctionParam(Parameter):
    """Parameter validating the ``function`` argument of a Transcode network.

    Accepted values are *None*, a callable, or (only when the owning object
    has no input vocabulary) a constant given as a string, `SemanticPointer`
    or `PointerSymbol`.
    """

    def coerce(self, obj, fn):
        """Validate *fn* for *obj* and return it.

        Parameters
        ----------
        obj
            Object owning the parameter; its ``input_vocab`` (and, for
            callables, ``size_in``) attributes are consulted.
        fn
            Candidate value.

        Raises
        ------
        ValidationError
            If *fn* is of an unsupported type, is a constant while *obj* has
            an input vocabulary, or is a callable with the wrong signature.
        """
        fn = super().coerce(obj, fn)
        if fn is None:
            return fn
        if callable(fn):
            return self.coerce_callable(obj, fn)
        # Compare against None explicitly (as the rest of this module does)
        # instead of relying on truthiness: a vocabulary object may evaluate
        # as falsy (e.g. an empty sized container), which would wrongly let a
        # constant output through even though an input vocabulary is set.
        if obj.input_vocab is None and isinstance(
            fn, (str, SemanticPointer, PointerSymbol)
        ):
            return fn
        raise ValidationError(
            f"Invalid output type {type(fn)!r}", attr=self.name, obj=obj
        )

    def coerce_callable(self, obj, fn):
        """Check that *fn* can be called with the arguments it will receive.

        The function is invoked once, at time ``0.0``, with dummy input: a
        zero `SemanticPointer` if *obj* has an input vocabulary, a zero array
        of length ``obj.size_in`` if that is set, and no input otherwise.

        Raises
        ------
        ValidationError
            If the probe call reports that *fn* could not be invoked with
            those arguments.
        """
        t = 0.0
        if obj.input_vocab is not None:
            args = (
                t,
                SemanticPointer(
                    np.zeros(obj.input_vocab.dimensions), vocab=obj.input_vocab
                ),
            )
        elif obj.size_in is not None:
            args = (t, np.zeros(obj.size_in))
        else:
            args = (t,)

        # Only the "was it invoked" flag matters; the return value is ignored.
        _, invoked = checked_call(fn, *args)
        if not invoked:
            if obj.input_vocab is not None:
                raise ValidationError(
                    f"Transcode function {fn} is expected to accept exactly 2 "
                    "arguments: time as a float, and a SemanticPointer",
                    attr=self.name,
                    obj=obj,
                )
            else:
                raise ValidationError(
                    f"Transcode function {fn} is expected to accept exactly 1 "
                    "or 2 arguments: time as a float, and optionally "
                    "the input data as NumPy array.",
                    attr=self.name,
                    obj=obj,
                )
        return fn

    @classmethod
    def to_node_output(cls, fn, input_vocab=None, output_vocab=None):
        """Convert a validated function value into a node output.

        Parameters
        ----------
        fn
            *None*, a callable, or a constant (string, `SemanticPointer` or
            `PointerSymbol`).
        input_vocab : optional
            If not *None*, a callable *fn* is wrapped with `make_sp_func` so
            that it receives a `SemanticPointer`.
        output_vocab : optional
            If not *None*, a callable *fn* is wrapped with `make_parse_func`
            so that its result is converted with this vocabulary. Constants
            are always converted with ``SpArrayExtractor(output_vocab)``.

        Returns
        -------
        None, callable, or the converted constant.

        Raises
        ------
        ValueError
            If *fn* is of an unsupported type.
        """
        if fn is None:
            return None
        elif callable(fn):
            if input_vocab is not None:
                fn = make_sp_func(fn, input_vocab)
            if output_vocab is not None:
                fn = make_parse_func(fn, output_vocab)
            return fn
        elif isinstance(fn, (str, SemanticPointer, PointerSymbol)):
            return SpArrayExtractor(output_vocab)(fn)
        else:
            raise ValueError(f"Invalid output type {type(fn)!r}")
class Transcode(Network):
    """Transcode from, to, and between Semantic Pointers.

    This can be thought of as the equivalent of a `nengo.Node` for Semantic
    Pointers.

    Either the *input_vocab* or the *output_vocab* argument must not be *None*.
    (If you want both arguments to be *None*, use a normal `nengo.Node`.)
    Which one of the parameters in the pairs *input_vocab/size_in* and
    *output_vocab/size_out* is not set to *None*, determines whether a Semantic
    Pointer input/output or a normal vector input/output is expected.

    Parameters
    ----------
    function : func, optional (Default: None)
        Function that transforms the input Semantic Pointer to an output
        Semantic Pointer. The function signature depends on *input_vocab*:

        * If *input_vocab* is *None*, the allowed signatures are the same as
          for a `nengo.Node`. Either ``function(t)`` or ``function(t, x)``
          where *t* (float) is the current simulation time and *x* (NumPy
          array) is the current input to transcode with size *size_in*.
        * If *input_vocab* is not *None*, the signature has to be
          ``function(t, sp)`` where *t* (float) is the current simulation time
          and *sp* (`.SemanticPointer`) is the current Semantic Pointer input.
          The associated vocabulary can be obtained via ``sp.vocab``.

        The allowed function return value depends on *output_vocab*:

        * If *output_vocab* is *None*, the return value must be a NumPy array
          (or equivalent) of size *size_out* or *None* (i.e. no return value)
          if *size_out* is *None*.
        * If *output_vocab* is not *None*, the return value can be either of:
          NumPy array, `.SemanticPointer` instance, or a SemanticPointer
          expression or symbolic expression as string that gets parsed with
          the *output_vocab*.
    input_vocab : Vocabulary, optional (Default: None)
        Input vocabulary. Mutually exclusive with *size_in*.
    output_vocab : Vocabulary, optional (Default: None)
        Output vocabulary. Mutually exclusive with *size_out*.
    size_in : int, optional (Default: None)
        Input size. Mutually exclusive with *input_vocab*.
    size_out : int, optional (Default: None)
        Output size. Mutually exclusive with *output_vocab*.
    **kwargs : dict
        Additional keyword arguments passed to `nengo_spa.Network`.

    Attributes
    ----------
    input : nengo.Node
        Input.
    output : nengo.Node
        Output.
    """

    function = TranscodeFunctionParam(
        "function", optional=True, default=None, readonly=True
    )
    input_vocab = VocabularyOrDimParam(
        "input_vocab", optional=True, default=None, readonly=True
    )
    output_vocab = VocabularyOrDimParam(
        "output_vocab", optional=True, default=None, readonly=True
    )
    size_in = IntParam("size_in", optional=True, default=None, readonly=True)
    size_out = IntParam("size_out", optional=True, default=None, readonly=True)

    def __init__(
        self,
        function=Default,
        input_vocab=Default,
        output_vocab=Default,
        size_in=Default,
        size_out=Default,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabs need to be set before function which accesses vocab for
        # validation.
        self.input_vocab = input_vocab
        self.output_vocab = output_vocab
        self.size_in = size_in
        self.size_out = size_out

        if self.input_vocab is None and self.output_vocab is None:
            raise ValidationError(
                "At least one of input_vocab and output_vocab needs to be "
                "set. If neither the input nor the output is a Semantic "
                "Pointer, use a basic nengo.Node instead.",
                self,
            )
        if self.input_vocab is not None and self.size_in is not None:
            raise ValidationError(
                "The input_vocab and size_in arguments are mutually exclusive.",
                "size_in",
                self,
            )
        if self.output_vocab is not None and self.size_out is not None:
            # Report the offending attribute as size_out (this check concerns
            # the output side, not size_in).
            raise ValidationError(
                "The output_vocab and size_out arguments are mutually exclusive.",
                "size_out",
                self,
            )

        # Validated last: the function parameter reads input_vocab/size_in.
        self.function = function

        # A vocabulary, when given, determines the node dimensionality;
        # otherwise fall back to the explicit size.
        node_size_in = (
            self.input_vocab.dimensions
            if self.input_vocab is not None
            else self.size_in
        )
        node_size_out = (
            self.output_vocab.dimensions
            if self.output_vocab is not None
            else self.size_out
        )

        if self.function is None:
            # Without a function the node has no output of its own; make sure
            # it has an input size (taken from the output vocabulary if
            # nothing else is known) and leave size_out unspecified.
            if node_size_in is None:
                node_size_in = self.output_vocab.dimensions
            node_size_out = None

        with self:
            self.node = nengo.Node(
                TranscodeFunctionParam.to_node_output(
                    self.function, self.input_vocab, self.output_vocab
                ),
                size_in=node_size_in,
                size_out=node_size_out,
            )
            # A single node serves as both input and output of the network.
            self.input = self.node
            self.output = self.node

        if self.input_vocab is not None:
            self.declare_input(self.input, self.input_vocab)
        if self.output_vocab is not None:
            self.declare_output(self.output, self.output_vocab)