"""Export to GEXF for visualization of networks in Gephi."""
from collections import Mapping, namedtuple, OrderedDict, Sequence
from datetime import date
import weakref
import xml.etree.ElementTree as et
import nengo
try:
import nengo_spa as spa
except ImportError:
spa = None
import numpy as np
[docs]class DispatchTable(object):
"""A descriptor to dispatch to other methods depending on argument type.
How to use: assign the descriptor to a class attribute and use the
``register`` decorator to declare which functions to dispatch to
for specific types::
class MyClass(object):
dispatch = DispatchTable()
@dispatch.register(TypeA)
def handle_type_a(self, obj_of_type_a):
# ...
@dispatch.register(TypeB)
def handle_type_b(self, obj_of_type_b):
# ...
To then call the method for the appropriate type::
inst = MyClass()
inst.dispatch(obj_of_type_a_or_b)
If multiple methods would match (e.g. if *TypeB* inherits from *TypeA*),
the most specific method will be used (to be precise: the first type in the
method resolution order with a registered method will be used).
The *DispatchTable* descriptor accepts another *DispatchTable* as argument
which will be used as a fallback. This allows to inherit the dispatch
table and selectively overwrite methods like so::
class Inherited(MyClass):
dispatch = DispatchTable(MyClass.dispatch)
@dispatch.register(TypeA)
def alternate_type_a_handler(self, obj_of_type_a):
# ...
Finally, dispatch methods can also be changed on a per-instance basis::
inst.dispatch.register(TypeA, inst_type_a_handler)
"""
[docs] class InstDispatch(object):
"""Return value when accessing the dispatch table on an instance."""
__slots__ = ('param', 'inst', 'owner')
def __init__(self, param, inst, owner):
self.param = param
self.inst = inst
self.owner = owner
def __call__(self, obj):
for cls in obj.__class__.__mro__:
if cls in self.param.inst_type_table.get(self.inst, {}):
return self.param.inst_type_table[self.inst][cls](obj)
elif cls in self.param.type_table:
return self.param.type_table[cls](self.inst, obj)
elif self.param.parent is not None:
try:
return self.param.parent.__get__(
self.inst, self.owner)(obj)
except NotImplementedError:
pass
raise NotImplementedError(
"Nothing to dispatch to for type {}.".format(type(obj)))
def register(self, type_, fn):
if self.inst not in self.param.inst_type_table:
self.param.inst_type_table[self.inst] = (
weakref.WeakKeyDictionary())
table = self.param.inst_type_table[self.inst]
table[type_] = fn
return fn
def __init__(self, parent=None):
self.type_table = weakref.WeakKeyDictionary()
self.inst_type_table = weakref.WeakKeyDictionary()
self.parent = parent
def register(self, type_):
def _register(fn):
assert type_ not in self.type_table
self.type_table[type_] = fn
return fn
return _register
def __get__(self, inst, owner):
if inst is None:
return self
return self.InstDispatch(self, inst, owner)
[docs]class HierarchicalLabeler(object):
"""Obtains labels for objects in a Nengo network.
The names will include the network hierarchy.
Usage example::
labels = HierarchicalLabeler().get_labels(model)
"""
dispatch = DispatchTable()
def __init__(self):
self._names = None
@dispatch.register(Sequence)
def get_labels_from_sequence(self, seq):
base_name = self._names[seq]
for i, obj in enumerate(seq):
self._handle_found_name(obj, '{base_name}[{i}]'.format(
base_name=base_name, i=i))
@dispatch.register(Mapping)
def get_labels_from_mapping(self, mapping):
base_name = self._names[mapping]
for k in mapping:
obj = mapping[k]
self._handle_found_name(obj, '{base_name}[{k}]'.format(
base_name=base_name, k=k))
@dispatch.register(object)
def get_labels_from_object(self, obj):
pass
@dispatch.register(nengo.Network)
def get_labels_from_network(self, net):
if net in self._names:
base_name = self._names[net] + '.'
else:
base_name = ''
check_last = {
'ensembles', 'nodes', 'connections', 'networks', 'probes'}
check_never = {
'all_ensembles', 'all_nodes', 'all_connections', 'all_networks',
'all_objects', 'all_probes'}
for name in dir(net):
if (not name.startswith('_') and
name not in check_last | check_never):
try:
attr = getattr(net, name)
except AttributeError:
pass
else:
self._handle_found_name(attr, base_name + name)
for name in check_last:
attr = getattr(net, name)
self._handle_found_name(attr, base_name + name)
def _handle_found_name(self, obj, name):
if (isinstance(obj, (nengo.base.NengoObject, nengo.Network)) and
obj not in self._names):
self._names[obj] = name
self.dispatch(obj)
def get_labels(self, model):
self._names = weakref.WeakKeyDictionary()
self.dispatch(model)
return self._names
Attr = namedtuple('Attr', ['id', 'type', 'default'])
[docs]class GexfConverter(object):
"""Converts Nengo models into GEXF files.
This can be loaded in Gephi for visualization of the model graph.
Links:
* `Gephi <https://gephi.org/>`_
* `GEXF <https://github.com/gephi/gexf/wiki>`_
This class can be inherited from to customize the conversion or
alternatively the ``dispatch`` table can be changed on a
per-instance basis.
Note that probes are currently not included in the graph.
The following attributes will be stored on graph nodes:
* *type*: type of the Nengo object (e.g., *nengo.ensemble.Ensemble*),
* *net*: unique ID of the containing network,
* *net_label*: (possibly non-unique) label of the containing network,
* *size_in*: input size,
* *size_out*: output_size,
* *radius*: ensemble radius (unset for other nodes),
* *n_neurons*: number of neurons (0 for non-ensembles),
* *neuron_type*: string representation of the neuron type (unset for
non-ensembles).
The following attributes will be stored on graph edges:
* *pre_type*: type of the connection's pre object (e.g.,
*nengo.ensemble.Neurons*),
* *post_type*: type of the connection's post object (e.g.,
*nengo.ensemble.Neurons*),
* *synapse*: string representation of the synapse,
* *tau*: the tau parameter of the synapse if existent,
* *function*: string representation of the connection's function,
* *transform*: string representation of the connection's transform,
* *scalar_transform*: float representation of the transform if it is a
scalar,
* *learning_rule_type*: string representation of the connection's learning
rule type.
Parameters
----------
labeler : optional
Object with a ``get_labels`` method that returns a dictionary mapping
model objects to labels. If not given, a new `HierarchicalLabeler`
will be used.
hierarchical : bool, optional (default: False)
Whether to include information of the network hierarchy in the file.
Support for hierarchical graphs was removed in Gephi 0.9 and
hierarchical networks will be automatically flattened which leaves an
unconnected node for every network.
Examples
--------
Basic usage to write a GEXF file::
GexfConverter().convert(model).write('model.gexf')
"""
dispatch = DispatchTable()
node_attrs = OrderedDict((
('type', Attr(0, 'string', None)),
('net', Attr(1, 'long', None)),
('net_label', Attr(2, 'string', None)),
('size_in', Attr(3, 'integer', None)),
('size_out', Attr(4, 'integer', None)),
('radius', Attr(5, 'float', None)),
('n_neurons', Attr(6, 'integer', 0)),
('neuron_type', Attr(7, 'string', None)),
))
edge_attrs = OrderedDict((
('pre_type', Attr(0, 'string', None)),
('post_type', Attr(1, 'string', None)),
('synapse', Attr(2, 'string', None)),
('tau', Attr(3, 'float', None)),
('function', Attr(4, 'string', None)),
('transform', Attr(5, 'string', None)),
('scalar_transform', Attr(6, 'float', 1.)),
('learning_rule_type', Attr(7, 'string', None)),
))
def __init__(self, labeler=None, hierarchical=False):
if labeler is None:
labeler = HierarchicalLabeler()
self.labeler = labeler
self.hierarchical = hierarchical
self.version = (1, 3)
self.tag = 'draft'
# State used during processing of a model
# WeakKeyDict so we don't prevent garbage collection after conversion
# finished.
self._labels = weakref.WeakKeyDictionary()
self._net = None
[docs] def convert(self, model):
"""Convert a model to GEXF format.
Returns
-------
xml.etree.ElementTree.ElementTree
Converted model.
"""
self._labels = self.labeler.get_labels(model)
self._labels[model] = 'model'
return self.make_document(model)
[docs] def make_document(self, model):
"""Create the GEXF XML document from *model*.
This method is exposed so it can be overwritten in inheriting classes.
Invoke `convert` instead of this method to convert a model.
Returns
-------
xml.etree.ElementTree.ElementTree
Converted model.
"""
version = '.'.join(str(i) for i in self.version)
tag_version = version + self.tag
gexf = et.Element('gexf', {
'xmlns': 'http://www.gexf.net/' + tag_version,
'xmlns:xsi': 'http://www.w3.org/2001/XMLSchema-instance',
'xsi:schemaLocation': (
'http://www.gexf.net/' + tag_version + ' ' +
'http://www.gexf.net/' + tag_version + '/gexf.xsd'),
'version': version
})
meta = et.SubElement(gexf, 'meta', {
'lastmodifieddate': date.today().isoformat()})
creator = et.SubElement(meta, 'creator')
creator.text = self.get_typename(self)
graph = et.SubElement(gexf, 'graph', {'defaultedgetype': 'directed'})
graph.append(self.make_attr_defs('node', self.node_attrs))
graph.append(self.make_attr_defs('edge', self.edge_attrs))
graph.append(self.dispatch(model))
edges = et.SubElement(graph, 'edges')
for c in model.all_connections:
elem = self.dispatch(c)
if elem is not None:
edges.append(elem)
return et.ElementTree(gexf)
[docs] def make_attr_defs(self, cls, defs):
"""Generate an attribute definition block.
Parameters
----------
cls : str
Class the attribute definitions are for ('node' or 'edge').
defs : dict
Attribute definitions. Maps attribute names to `Attr` instances.
Returns
-------
xml.etree.ElementTree.Element
"""
attributes = et.Element('attributes', {'class': cls})
for k, d in defs.items():
attr = et.SubElement(attributes, 'attribute', {
'id': str(d.id),
'title': k,
'type': d.type,
})
if d.default is not None:
default = et.SubElement(attr, 'default')
default.text = str(d.default)
return attributes
[docs] def make_attrs(self, defs, attrs):
"""Generates a block of attribute values.
Parameters
----------
defs : dict
Attribute definitions. Maps attribute names to `Attr` instances.
attrs : dict
Mapping of attribute names to assigned values.
Returns
-------
xml.etree.ElementTree.Element
"""
values = et.Element('attvalues')
assert all(k in defs for k in attrs.keys())
for k, d in defs.items():
if k in attrs and attrs[k] is not None:
values.append(et.Element('attvalue', {
'for': str(d.id),
'value': str(attrs[k]),
}))
return values
[docs] def make_node(self, obj, **attrs):
"""Generate a node for *obj* with attributes *attrs*."""
tag_attrib = {'id': str(id(obj))}
if obj in self._labels:
tag_attrib['label'] = self._labels[obj]
node = et.Element('node', tag_attrib)
if len(attrs) > 0:
node.append(self.make_attrs(self.node_attrs, attrs))
return node
[docs] def make_edge(self, obj, source, target, **attrs):
"Edge for *obj* from *source* to *target* with attributes *attrs*."
tag_attrib = {
'id': str(id(obj)),
'source': str(id(source)),
'target': str(id(target))
}
edge = et.Element('edge', tag_attrib)
if len(attrs) > 0:
edge.append(self.make_attrs(self.edge_attrs, attrs))
return edge
@dispatch.register(nengo.Network)
def convert_network(self, net):
parent_net = self._net
self._net = net
nodes = et.Element('nodes')
leaves = net.ensembles + net.nodes + net.probes
for leave in leaves:
leave_elem = self.dispatch(leave)
if leave_elem is not None:
nodes.append(leave_elem)
if self.hierarchical:
for subnet in net.networks:
subnet_node = self.make_node(
subnet, type=self.get_typename(subnet), net=id(self._net),
net_label=self._labels.get(self._net, None),
n_neurons=subnet.n_neurons)
subnet_node.append(self.dispatch(subnet))
nodes.append(subnet_node)
else:
for subnet in net.networks:
nodes.extend(self.dispatch(subnet))
self._net = parent_net
return nodes
@dispatch.register(nengo.Ensemble)
def convert_ensemble(self, ens):
return self.make_node(
ens,
type=self.get_typename(ens),
net=id(self._net),
net_label=self._labels.get(self._net, None),
size_in=ens.dimensions,
size_out=ens.dimensions,
radius=ens.radius,
n_neurons=ens.n_neurons,
neuron_type=ens.neuron_type,
)
@dispatch.register(nengo.Node)
def convert_node(self, node):
return self.make_node(
node,
type=self.get_typename(node),
net=id(self._net),
net_label=self._labels.get(self._net, None),
size_in=node.size_in,
size_out=node.size_out,
)
@dispatch.register(nengo.Probe)
def convert_probe(self, probe):
return None
@dispatch.register(nengo.Connection)
def convert_connection(self, conn):
source = self.get_node_obj(conn.pre_obj)
target = self.get_node_obj(conn.post_obj)
return self.make_edge(
conn, source, target,
pre_type=self.get_typename(conn.pre_obj),
post_type=self.get_typename(conn.post_obj),
synapse=conn.synapse,
tau=conn.synapse.tau if hasattr(conn.synapse, 'tau') else None,
function=conn.function,
transform=conn.transform,
scalar_transform=(
conn.transform if np.isscalar(conn.transform) else None),
learning_rule_type=conn.learning_rule_type
)
[docs] def get_node_obj(self, obj):
"""Get an object with a corresponding graph node related to *obj*.
For certain objects like `nengo.ensemble.Neurons` or
`nengo.connection.LearningRule` no graph node will be created. This
function will resolve such an object to a related object that has a
corresponding graph node (e.g., the ensemble for a neurons object or
the pre object for a learning rule).
In `GexfConverter` this is used to make sure connections are between
the correct nodes and do not introduce unrelated dangling nodes.
"""
if isinstance(obj, nengo.ensemble.Neurons):
return obj.ensemble
elif isinstance(obj, nengo.connection.LearningRule):
return self.get_node_obj(obj.connection.pre_obj)
return obj
@classmethod
def get_typename(cls, obj):
tp = type(obj)
return tp.__module__ + '.' + tp.__name__
[docs]class CollapsingGexfConverter(GexfConverter):
"""Converts Nengo models into GEXF files with some collapsed networks.
See `GexfConverter` for general information on conversion to GEXF files.
This class will collapse certain networks to a single node in the
conversion.
Parameters
----------
to_collapse : sequence, optional
Network types to collapse, if not given the networks listed in
``NENGO_NETS`` and ``SPA_NETS`` will be collapsed. Note that
``SPA_NETS`` currently only contains networks from *nengo_spa*,
but not the *spa* module in core *nengo*.
labeler : optional
Object with a ``get_labels`` method that returns a dictionary mapping
model objects to labels. If not given, a new `HierarchicalLabeler`
will be used.
hierarchical : bool, optional (default: False)
Whether to include information of the network hierarchy in the file.
Support for hierarchical graphs was removed in Gephi 0.9 and
hierarchical networks will be automatically flattened which leaves an
unconnected node for every network.
"""
dispatch = DispatchTable(GexfConverter.dispatch)
NENGO_NETS = (
nengo.networks.CircularConvolution,
nengo.networks.EnsembleArray,
nengo.networks.Product)
if spa is None:
SPA_NETS = ()
else:
SPA_NETS = (
spa.networks.CircularConvolution,
spa.AssociativeMemory,
spa.Bind,
spa.Compare,
spa.Product,
spa.Scalar,
spa.State,
spa.Transcode)
def __init__(self, to_collapse=None, labeler=None, hierarchical=False):
super(CollapsingGexfConverter, self).__init__(
labeler=labeler, hierarchical=hierarchical)
if to_collapse is None:
to_collapse = self.NENGO_NETS + self.SPA_NETS
for cls in to_collapse:
self.dispatch.register(cls, self.convert_collapsed)
self.obj2collapsed = weakref.WeakKeyDictionary()
[docs] def convert_collapsed(self, net):
"""Used to convert a network into a collapsed graph node."""
nodes = et.Element('nodes')
nodes.append(self.make_node(
net, type=self.get_typename(net), net=id(self._net),
net_label=self._labels.get(self._net, None),
n_neurons=net.n_neurons))
self.obj2collapsed.update({
child: net for child in net.all_objects})
return nodes
def get_node_obj(self, obj):
obj = super(CollapsingGexfConverter, self).get_node_obj(obj)
return self.obj2collapsed.get(obj, obj)