# Source code for nengo.builder.network

import contextlib
import logging

import numpy as np

import nengo.utils.numpy as npext
from nengo.builder import Builder
from nengo.connection import Connection
from nengo.ensemble import Ensemble
from nengo.network import Network
from nengo.node import Node
from nengo.probe import Probe
from nengo.utils.progress import Progress

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def nullcontext():
    """Context manager that does nothing on entry or exit and yields ``None``.

    For single use it behaves like ``contextlib.nullcontext()``, which only
    exists on Python 3.7+. It stands in for a real context manager when no
    setup or teardown is needed.
    """
    yield


@Builder.register(Network)  # noqa: C901
def build_network(model, network, progress=None):
    """Builds a `.Network` object into a model.

    The network builder does this by mapping each high-level object to its
    associated signals and operators one-by-one, in the following order:

    1. Ensembles, nodes, neurons
    2. Subnetworks (recursively)
    3. Connections, learning rules
    4. Probes

    Before calling any of the individual objects' build functions, random
    number seeds are assigned to objects that did not have a seed explicitly
    set by the user. Whether the seed was assigned manually or automatically
    is tracked, and the decoder cache is only used when the seed is assigned
    manually.

    Parameters
    ----------
    model : Model
        The model to build into.
    network : Network
        The network to build.
    progress : Progress, optional
        Object used to track the build progress.

        Note that this will only affect top-level networks.

    Notes
    -----
    Sets ``model.params[network]`` to ``None``.

    ``model.config`` is restored to its previous value even if building a
    child object raises. For the top-level network, ``model.build_callback``
    is also cleared in that case, so a failed build does not leave the model
    in a half-configured state.
    """
    if model.toplevel is None:
        # The first network built becomes the top-level network; seeds for
        # every object in the hierarchy are assigned once, up front.
        model.toplevel = network
        seed_network(network, seeds=model.seeds, seeded=model.seeded)

        if progress is not None:
            # number of sub-objects, plus 1 to account for this network
            progress.max_steps = len(network.all_objects) + 1

            def build_callback(obj):
                # Only step for objects whose type is one of the keys of
                # ``network.objects``.
                if isinstance(obj, tuple(network.objects)):
                    progress.step()

            model.build_callback = build_callback

    if progress is None:
        progress = Progress()  # dummy progress

    # Set config (restored in the ``finally`` below)
    old_config = model.config
    model.config = network.config

    # If this is the toplevel network, enter the decoder cache
    context = model.decoder_cache if model.toplevel is network else nullcontext()
    try:
        with context, progress:
            logger.debug("Network step 1: Building ensembles and nodes")
            for obj in network.ensembles + network.nodes:
                model.build(obj)

            logger.debug("Network step 2: Building subnetworks")
            for subnetwork in network.networks:
                model.build(subnetwork)

            logger.debug("Network step 3: Building connections")
            for conn in network.connections:
                # NB: we do these in the order in which they're defined, and
                # build the learning rule in the connection builder. Because
                # learning rules are attached to connections, the connection
                # that contains the learning rule (and the learning rule) are
                # always built *before* a connection that attaches to that
                # learning rule. Therefore, we don't have to worry about
                # connection ordering here.
                # TODO: Except perhaps if the connection being learned
                # is in a subnetwork?
                model.build(conn)

            logger.debug("Network step 4: Building probes")
            for probe in network.probes:
                model.build(probe)

            if context is model.decoder_cache:
                model.decoder_cache.shrink()

            if model.toplevel is network:
                progress.step()
    finally:
        # Always undo the temporary model state, on success or failure.
        # The callback closes over ``progress``, which has already been
        # exited at this point.
        if model.toplevel is network:
            model.build_callback = None
        # Unset config
        model.config = old_config

    model.params[network] = None
def seed_network(network, seeds, seeded, base_rng=np.random):
    """Fill the seeding maps for ``network`` and everything nested inside it.

    Objects in subnetworks are included, so every object reachable from
    ``network`` ends up with an entry in both ``seeds`` and ``seeded``.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    network : Network
        Root network whose objects (including those of its subnetworks)
        should receive seeds.
    seeds : {object: int}
        Map from objects to their integer seeds. Updated in place; an object
        that already has an entry keeps it.
    seeded : {object: bool}
        Map from objects to whether their seed is fixed, either set on the
        object itself or inherited from a seeded parent network (True), or
        randomly generated (False). Updated in place; existing entries are
        kept.
    base_rng : np.random.RandomState
        Random number generator used to draw the seed of ``network`` itself.
    """
    # The root's seed comes from ``base_rng``. Its children then draw from
    # RNGs seeded by their containing network's seed.
    _set_seed(seeds, network, base_rng)
    _set_seeded(seeded, network)

    # Descend into all child objects and subnetworks.
    _seed_network(network, seeds, seeded)
def _seed_network(network, seeds, seeded):
    """Assign seeds to every object below ``network``, including subnetworks.

    The children of each visited network draw their seeds from an RNG
    initialised with that network's own entry in ``seeds``. That entry must
    therefore exist before the network is visited. Networks are visited in
    pre-order: a network first, then each of its subnetworks in order, fully
    depth-first.
    """
    # Probes come last so that they never shift the seeds drawn for the
    # other object types.
    type_order = (Connection, Ensemble, Network, Node, Probe)

    # Explicit stack in place of recursion. Pushing subnetworks in reverse
    # makes them pop in their original order, matching a recursive walk.
    to_visit = [network]
    while to_visit:
        current = to_visit.pop()
        rng = np.random.RandomState(seeds[current])
        assert all(obj_type in type_order for obj_type in current.objects)

        for obj_type in type_order:
            for child in current.objects[obj_type]:
                _set_seed(seeds, child, rng)
                _set_seeded(seeded, child, parent=current)

        to_visit.extend(reversed(current.networks))


def _set_seed(seeds, obj, rng):
    """Store a seed for ``obj`` in ``seeds`` unless it already has one.

    A user-specified ``obj.seed`` takes precedence over a drawn value.
    """
    # Draw unconditionally. The RNG advances exactly once per object, so
    # seeding one object by hand never changes the seeds of the others.
    drawn = rng.randint(npext.maxint)
    if obj in seeds:
        return  # never replace an existing seed
    user_seed = getattr(obj, "seed", None)
    seeds[obj] = drawn if user_seed is None else user_seed


def _set_seeded(seeded, obj, parent=None):
    """Record whether ``obj``'s seed is fixed (True) or random (False).

    A seed counts as fixed if ``obj`` has its own non-None ``seed``
    attribute, or if ``parent`` was itself recorded as fixed.
    """
    # Never overwrite an existing entry. It records how the seed was
    # originally assigned (deterministically or randomly), and re-deriving
    # it could be wrong, e.g. if ``obj.seed`` has changed since.
    if obj in seeded:
        return
    if getattr(obj, "seed", None) is not None:
        seeded[obj] = True
    else:
        # Inherit from the parent. With no parent, the seed is random.
        seeded[obj] = parent is not None and seeded[parent]