inferno
Request: add user-defined attributes to nodes of containers.Graph
- inferno version: 0.4.0
- Python version: 3.7
- Operating System: Ubuntu-18.04
Description
I would like to set attributes on the nodes of containers.Graph. If this feature were implemented, the container could be used more like a regular (NetworkX) graph.
What I Did (requested usage)
import torch.nn as nn
from inferno.extensions.containers import Graph

module = Graph()
module.add_input_node('input')
# Requested keyword argument; it is not part of the current add_node signature
module.add_node('l1', nn.Linear(3, 18), previous='input', attributes={"hoge": [1, 2, 3]})
g = module.graph
print(g.nodes(data=True))
# Desired output, either nested:
# [('input', {'is_input_node': True}), ('l1', {'attributes': {"hoge": [1, 2, 3]}})]
# or flat:
# [('input', {'is_input_node': True}), ('l1', {"hoge": [1, 2, 3]})]
If I understand your code snippet correctly, you want to be able to pass a dict attributes to add_node, correct?
How would you use these in the nodes? Access them via node.attributes?
Thank you for your reply.
If I understand your code snippet correctly, you want to be able to pass a dict attributes to add_node, correct?
Yes, you're right.
How would you use these in the nodes? Access them via node.attributes?
I want to use them through the native NetworkX interfaces. NetworkX exposes node attributes as shown below, not via node.attributes.
>>> import networkx as nx
>>> G = nx.Graph()
>>> G.add_node(1, time='5pm')
>>> G.add_nodes_from([3], time='2pm')
>>> G.nodes[1]
{'time': '5pm'}
>>> G.nodes[1]['room'] = 714
>>> G.nodes.data()
NodeDataView({1: {'time': '5pm', 'room': 714}, 3: {'time': '2pm'}})
I think a simple implementation would look like the following (a change to the Graph class in inferno/extensions/containers/graph.py):
# Method of the Graph class; nn (torch.nn) and pyu (inferno's python_utils)
# are assumed to be imported at module level in graph.py.
def add_node(self, name, module, previous=None, **attr):
    """
    Add a node to the graph.

    Parameters
    ----------
    name : str
        Name of the node. Nodes are identified by their names.
    module : torch.nn.Module
        Torch module for this node.
    previous : str or list of str
        (List of) name(s) of the previous node(s).
    **attr : keyword arguments
        Attributes of the node, stored as node data in the underlying networkx graph.

    Returns
    -------
    Graph
        self
    """
    assert isinstance(module, nn.Module)
    self.add_module(name, module)
    # Forward the user-defined attributes to networkx as node data.
    self.graph.add_node(name, **attr)
    if previous is not None:
        for _previous in pyu.to_iterable(previous):
            self.add_edge(_previous, name)
    return self
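To make the difference between the two desired outputs in the request concrete, here is a minimal sketch that needs only the standard library. ToyGraph is a hypothetical stand-in (not part of inferno or NetworkX) that keeps node data as a dict of dicts, roughly the way NetworkX exposes G.nodes[name]. Forwarding **attr, as in the proposed add_node, gives the flat form; passing an attributes= dict through the same path gives the nested form.

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union


class ToyGraph:
    """Hypothetical stand-in mirroring the proposed add_node(..., **attr).

    Node data is a dict of dicts, similar to networkx's G.nodes[name]
    attribute dicts. Neither torch nor networkx is required.
    """

    def __init__(self) -> None:
        self._nodes: Dict[str, Dict[str, Any]] = {}
        self._edges: List[Tuple[str, str]] = []

    def add_input_node(self, name: str) -> "ToyGraph":
        self._nodes.setdefault(name, {}).update(is_input_node=True)
        return self

    def add_node(self, name: str, module: Any,
                 previous: Optional[Union[str, Iterable[str]]] = None,
                 **attr: Any) -> "ToyGraph":
        # Like networkx, re-adding an existing node merges the new attributes.
        self._nodes.setdefault(name, {}).update(attr)
        if previous is not None:
            prevs = [previous] if isinstance(previous, str) else list(previous)
            for prev in prevs:
                self._edges.append((prev, name))
        return self

    @property
    def nodes(self) -> Dict[str, Dict[str, Any]]:
        return self._nodes

    def nodes_data(self) -> List[Tuple[str, Dict[str, Any]]]:
        # Rough analogue of networkx's G.nodes(data=True).
        return list(self._nodes.items())


g = ToyGraph()
g.add_input_node("input")
# Keyword arguments are stored directly on the node (flat form).
g.add_node("l1", module=None, previous="input", hoge=[1, 2, 3])
# An explicit attributes= dict becomes one nested key (nested form).
g.add_node("l2", module=None, previous="l1", attributes={"hoge": [1, 2, 3]})

print(g.nodes_data())
print(g.nodes["l1"]["hoge"])
```

One caveat of the **attr design: attribute keys that collide with the parameter names (name, module, previous) cannot be passed this way. The nested attributes= form avoids that, at the cost of one extra level of lookup.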