MXNet.jl

Representing the network as a graph.

Open — vchuravy opened this issue 9 years ago • 1 comment

Is there any way, besides parsing the JSON, to work with the network as a graph?

using MXNet
using LightGraphs
using JSON

# Example network built with MXNet.jl's `@mx.chain` macro, where each `=>`
# presumably feeds the previous node in as the input of the next layer.
# Layers: three 3x3 convolutions (64, 128, then 2 filters; pad (1, 1),
# stride (1, 1)), each followed by a PReLU activation (`LeakyReLU` with
# `act_type = :prelu`), ending in a softmax output named `:softmax` with
# `multi_output=true`.
# NOTE(review): the learnable parameters (convolution weights/biases, PReLU
# slopes) presumably appear as separate variable nodes in `mx.to_json(net)`,
# i.e. as extra inputs of each layer node in the graph built below — confirm.
net = @mx.chain mx.Variable(:data) =>
mx.Convolution(num_filter = 64, kernel = (3,3), pad = (1, 1), stride=(1,1)) =>
mx.LeakyReLU(act_type = :prelu) => 
mx.Convolution(num_filter = 128, kernel = (3,3), pad = (1, 1), stride=(1,1)) =>
mx.LeakyReLU(act_type = :prelu) => 
mx.Convolution(num_filter = 2, kernel = (3,3), pad = (1, 1), stride=(1,1)) =>
mx.LeakyReLU(act_type = :prelu) =>
mx.SoftmaxOutput(name=:softmax, multi_output=true)

# Mirror the serialized symbol as a LightGraphs directed graph.
# Vertex `i` stands for `nodes[i]`. For each entry of `nodes[v]["inputs"]` an
# edge `u -> v` is added, where `u` is the entry's first element shifted by one
# (the `+ 1` maps MXNet's 0-based node numbering onto Julia's 1-based vertices).
js_net = JSON.parse(mx.to_json(net))
nodes = js_net["nodes"]
graph = DiGraph(length(nodes))
for (consumer, node) in enumerate(nodes), (producer, _) in node["inputs"]
    add_edge!(graph, producer + 1, consumer)
end

I am using this right now to determine which parts of the network should be frozen.

vchuravy avatar Apr 28 '16 06:04 vchuravy

As an example of how this can be useful, suppose you want to freeze the first few layers of a network:

"""
Find nodes that are step away from the root
"""
function findNodes(graph, step)
    Set(findNodes(graph, 1, step))
end

"""
    findNodes(graph, n, step)

Collect vertex `n`, every vertex reachable from `n` by following at most `step`
out-edges, and the in-neighbours of each of those vertices. Returns a
`Vector{Int}` with `n` first and without duplicates. When `step < 0`, returns an
empty vector.

The graph is expanded breadth-first and each vertex is expanded at most once.
A naive recursive walk instead re-expands a vertex once per distinct path
reaching it, which grows exponentially when branches merge again.
"""
function findNodes(graph, n, step)
    # Anchor: a negative hop budget selects nothing.
    step < 0 && return Int[]

    found = Int[]
    expanded = Set{Int}((n,))   # vertices already queued for expansion
    frontier = Int[n]           # vertices at the current BFS depth
    for budget in step:-1:0
        next_frontier = Int[]
        for v in frontier
            # Keep the vertex and the nodes feeding into it even if none of its
            # successors ends up being visited.
            push!(found, v)
            append!(found, in_neighbors(graph, v))
            budget > 0 || continue   # successors would be out of budget
            for w in out_neighbors(graph, v)
                # BFS reaches `w` first at its shortest distance, i.e. with the
                # largest remaining budget, so later arrivals add nothing new.
                if !(w in expanded)
                    push!(expanded, w)
                    push!(next_frontier, w)
                end
            end
        end
        isempty(next_frontier) && break
        frontier = next_frontier
    end

    # Stable dedup: keeps first occurrences, so `n` stays first.
    return unique!(found)
end

# Tag every node within 4 forward hops of the root vertex (plus the direct
# inputs of those nodes) as frozen, then rebuild the symbol from the edited JSON.
# `nodes` is the same object as `js_net["nodes"]`, so the edits are visible
# when `js_net` is re-serialized.
for i in findNodes(graph, 4)
    # Merge the flag into the node's existing attribute dictionary instead of
    # replacing it; overwriting "attr" would silently drop whatever the node
    # already stored there.
    # NOTE(review): confirm "grad" => "freeze" is the flag the training code reads.
    node_attrs = get!(() -> Dict{String,Any}(), nodes[i], "attr")
    node_attrs["grad"] = "freeze"
end

net = mx.from_json(JSON.json(js_net), mx.SymbolicNode)

But it is rather awkward to round-trip through JSON instead of walking the network directly and using mx.set_attr.

vchuravy avatar Apr 28 '16 07:04 vchuravy