MXNet.jl
Representing the network as a graph
Is there any way, other than parsing the JSON, to work with the network as a graph?
using MXNet
using LightGraphs
using JSON
# Example network: three 3x3 convolutions (64, 128 and 2 filters; stride 1,
# pad 1), each followed by a LeakyReLU with act_type = :prelu, ending in a
# SoftmaxOutput named :softmax. `@mx.chain` links each layer to the next via `=>`.
# multi_output=true presumably requests a per-position softmax over the 2 output
# channels — TODO confirm against the MXNet SoftmaxOutput documentation.
net = @mx.chain mx.Variable(:data) =>
mx.Convolution(num_filter = 64, kernel = (3,3), pad = (1, 1), stride=(1,1)) =>
mx.LeakyReLU(act_type = :prelu) =>
mx.Convolution(num_filter = 128, kernel = (3,3), pad = (1, 1), stride=(1,1)) =>
mx.LeakyReLU(act_type = :prelu) =>
mx.Convolution(num_filter = 2, kernel = (3,3), pad = (1, 1), stride=(1,1)) =>
mx.LeakyReLU(act_type = :prelu) =>
mx.SoftmaxOutput(name=:softmax, multi_output=true)
# Serialize the symbol to JSON and mirror its node list as a directed graph:
# vertex i corresponds to nodes[i], with an edge from each input to its consumer.
js_net = JSON.parse(mx.to_json(net))
nodes = js_net["nodes"]
graph = DiGraph(length(nodes))
for (dst, node) in enumerate(nodes), input in node["inputs"]
    # The first element of an input entry is the producing node's id; the `+ 1`
    # shifts it (0-based, judging by the offset) onto 1-based vertex numbers.
    src, _ = input
    add_edge!(graph, src + 1, dst)
end
I am currently using this to determine which parts of the network should be frozen.
As an example of how this can be useful, say you want to freeze the first few layers of a network:
"""
    findNodes(graph, step; root = 1)

Return the `Set{Int}` of vertices that lie at most `step` forward edges
(following `out_neighbors`) away from `root`, together with the direct inputs
(`in_neighbors`) of each of those vertices. Returns an empty set when
`step < 0`.

With the graph built from the MXNet JSON node list, vertex 1 is the first JSON
node (presumably the `:data` variable), so the default `root = 1` starts the
walk at the network input. Pass `root` to start from a different vertex.
"""
function findNodes(graph, step; root = 1)
    return Set(findNodes(graph, root, step))
end
"""
    findNodes(graph, n, step)

Collect the vertices reachable from `n` in at most `step` forward hops (via
`out_neighbors`), plus every direct input (`in_neighbors`) of those vertices.
Returns a sorted `Vector{Int}` without duplicates; empty when `step < 0`.

A breadth-first search expands each vertex at most once. A plain recursive
walk instead re-expands a vertex once per path that reaches it, which grows
quickly on branching graphs (e.g. residual connections).
"""
function findNodes(graph, n, step)
    # Anchor: a negative budget selects nothing, not even `n` itself.
    step < 0 && return Int[]
    reached = Set{Int}([n])
    frontier = Int[n]
    # `1:step` runs floor(step) times, i.e. exactly the hop counts k with k <= step.
    for _ in 1:step
        next_frontier = Int[]
        for v in frontier, w in out_neighbors(graph, v)
            if !(w in reached)
                push!(reached, w)
                push!(next_frontier, w)
            end
        end
        # Nothing new reachable: further hops cannot add vertices.
        isempty(next_frontier) && break
        frontier = next_frontier
    end
    # Include the direct inputs of every reached vertex (e.g. weight/bias
    # variables feeding a layer), even if they are not on a forward path.
    selected = copy(reached)
    for v in reached
        union!(selected, in_neighbors(graph, v))
    end
    return sort!(collect(selected))
end
# Flag every vertex returned by `findNodes(graph, 4)` with "grad" => "freeze",
# then rebuild the symbol from the edited JSON. `nodes` aliases
# `js_net["nodes"]`, so mutating its entries updates `js_net`. The flag is
# merged into any existing "attr" dict rather than replacing it, so attributes
# already present on a node are preserved.
for i in findNodes(graph, 4)
    node_attrs = get!(nodes[i], "attr", Dict{String,Any}())
    node_attrs["grad"] = "freeze"
end
net = mx.from_json(JSON.json(js_net), mx.SymbolicNode)
But it is rather awkward to round-trip through JSON instead of walking the network directly and setting attributes with `mx.set_attr`.