
Differential Addition for Closures

Open · willtebbutt opened this issue on Jul 05, 2020 · 4 comments

I guess these are a special case of things with constrained values, maybe?

Anyway, the following

function make_foo(a_number)
    return x -> a_number * x
end

foo = make_foo(5.0)
foo + Composite{typeof(foo)}(; a_number=4.0)

yields

ERROR: MethodError: no method matching var"#44#45"{Float64}(::Float64)
Stacktrace:
 [1] macro expansion at /Users/willtebbutt/.julia/packages/ChainRulesCore/leg7n/src/differentials/composite.jl:0 [inlined]
 [2] construct(::Type{var"#44#45"{Float64}}, ::NamedTuple{(:a_number,),Tuple{Float64}}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/leg7n/src/differentials/composite.jl:168
 [3] +(::var"#44#45"{Float64}, ::Composite{var"#44#45"{Float64},NamedTuple{(:a_number,),Tuple{Float64}}}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/leg7n/src/differential_arithmetic.jl:96
 [4] top-level scope at REPL[40]:1
 [5] eval(::Module, ::Any) at ./boot.jl:331
 [6] eval_user_input(::Any, ::REPL.REPLBackend) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.4/REPL/src/REPL.jl:86
 [7] run_backend(::REPL.REPLBackend) at /Users/willtebbutt/.julia/packages/Revise/BqeJF/src/Revise.jl:1184
 [8] top-level scope at REPL[2]:0

which I think stems from the fact that we're not able to create new instances of foo using the current approach.
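
For reference, `construct` bottoms out in a call to the type's constructor, and closure types don't get one, so there is nothing for it to call (hypothetical REPL check, mirroring the MethodError in the trace above):

julia> typeof(foo)(9.0)  # closures have no inner or outer constructors
ERROR: MethodError: no method matching var"#44#45"{Float64}(::Float64)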

willtebbutt avatar Jul 05 '20 23:07 willtebbutt

@MasonProtter kindly pointed out on slack that one can construct new instances of closures as follows:

julia> macro new(T, vals...)
           esc(Expr(:new, T, vals...))
       end
@new (macro with 1 method)

julia> f(x) = y -> x*y + 1
f (generic function with 1 method)

julia> cl = f(2)
#5 (generic function with 1 method)

julia> cl(3)
7

julia> cl′ = @new typeof(cl) 3
#5 (generic function with 1 method)

julia> cl′(3)
10
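
For use inside something like `construct`, the same trick can be wrapped in a generated function instead of a macro, since there the type is only known from the argument. A minimal sketch (the name construct_via_new is hypothetical; note that :new does not call convert, so the supplied values must already have the exact field types):

# Hypothetical: like construct(T, nt), but uses :new so it also works
# for types without constructors, e.g. closures.
@generated function construct_via_new(::Type{T}, fields::NamedTuple) where {T}
    args = (:(getfield(fields, $(QuoteNode(name)))) for name in fieldnames(T))
    return Expr(:new, T, args...)
end

julia> construct_via_new(typeof(cl), (x = 3,))(3)
10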

willtebbutt avatar Jul 05 '20 23:07 willtebbutt

Yep, this was left as pending-clear-need during the Composite PR. If you have found a need then we can add it.

oxinabox avatar Jul 06 '20 18:07 oxinabox

In that case I'll leave this open -- I was trying to write a test for both rand_tangent and difference that involves a closure, as opposed to doing actual work I care about that happens to involve one. If someone has a spare half hour, this would be a great thing to get sorted, because it will almost certainly crop up at some point in the future.
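
For concreteness, the tests I had in mind look something like this (the exact signatures of rand_tangent and difference are assumed here; both currently live in FiniteDifferences.jl):

using Test

foo = make_foo(5.0)

# rand_tangent(foo) should produce a Composite{typeof(foo)} with a random
# a_number field; adding it back onto foo is exactly what fails today.
t = rand_tangent(foo)
@test foo + t isa typeof(foo)

# difference(a, b) should produce a differential d with b + d ≈ a.
d = difference(make_foo(9.0), foo)
@test (foo + d)(2.0) ≈ 18.0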

willtebbutt avatar Jul 07 '20 10:07 willtebbutt

I have a need for this: not for differentiable programming, but for difference-based optimization. My scheme is to use ChainRules' types to make e.g. Particle Swarm Optimizers usable without turning things into vectors.

If I wanted to do basic neural architecture search and I had this, I could do it really neatly.

Sketch based on https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl:

using Flux
using Flux: onehotbatch, onecold, crossentropy  # used below, but not imported in the original sketch
using Flux.Data: DataLoader
using MLDatasets: MNIST
using Statistics
using ChainRulesCore: difference  # this needs to move out of FiniteDifferences.jl

##########################################
# Neural Net Code

const train_data, test_data = let
    batchsize = 100
    # Loading Dataset   
    xtrain, ytrain = MNIST.traindata(Float32)
    xtest, ytest = MNIST.testdata(Float32)
    
    xtrain = Flux.flatten(xtrain)
    xtest = Flux.flatten(xtest)
    ytrain = onehotbatch(ytrain, 0:9)
    ytest = onehotbatch(ytest, 0:9)

    # Batching
    train_data = DataLoader(xtrain, ytrain, batchsize=batchsize, shuffle=true)
    test_data = DataLoader(xtest, ytest, batchsize=batchsize)

    train_data, test_data
end

function test_accuracy(model)
    return mean(test_data) do (x, y)
        sum(onecold(model(x)) .== onecold(y)) / size(x,2)
    end
end

# the particle *is* a NN_constructor
function train_and_evaluate(NN_constructor)
    model = NN_constructor()
    opt = ADAM()
    num_epochs = 10
    loss(x, y) = crossentropy(model(x), y)  # the loss must run the model; bare crossentropy would compare raw inputs to labels
    for ii in 1:num_epochs
        Flux.train!(loss, params(model), train_data, opt)
    end
    return test_accuracy(model)
end

#######################################
# PSO


mutable struct ParticleMetadata{P, D}
    position::P  # e.g. a NN_constructor
    velocity::D  # e.g. a differential for a NN_constructor
    pbest::P  # e.g. a NN_constructor
    pbest_score::Float64
end


function pso_evolve!(
    pdata::ParticleMetadata, gbest;
    g_rate=0.5, p_rate=0.3, momentum=0.9
)
    # as used here, difference(a, b) yields a differential d with b + d ≈ a,
    # so this update steers each particle towards gbest and its own pbest
    pdata.velocity =
        momentum * pdata.velocity +
        g_rate * difference(gbest, pdata.position) +
        p_rate * difference(pdata.pbest, pdata.position)

    # closure + differential: exactly the operation this issue asks for
    pdata.position += pdata.velocity

    score = train_and_evaluate(pdata.position)
    if score > pdata.pbest_score
        pdata.pbest = pdata.position
        pdata.pbest_score = score
    end
    return score
end

function pso_optimise(particle_creator;
    g_rate=0.5, p_rate=0.3, momentum=0.9,
    num_particles=100, num_generations=10
)

    tagged_particles = map(1:num_particles) do _
        pos = particle_creator()  # random initial position
        # random initial velocity
        vel = difference(particle_creator(), particle_creator())
        score = train_and_evaluate(pos)
        return ParticleMetadata(pos, vel, pos, score)  # pbest starts at the initial position
    end

    generation = 1
    g_score = -Inf
    g_best = first(tagged_particles).pbest
    while true
        scores = [p.pbest_score for p in tagged_particles]
        new_g_score, idx = findmax(scores)
        if new_g_score >= g_score
            g_score = new_g_score
            g_best = tagged_particles[idx].pbest
        end

        # maybe put a callback to do some plotting etc here

        generation > num_generations && return g_best
        pso_evolve!.(tagged_particles, Ref(g_best); g_rate=g_rate, p_rate=p_rate, momentum=momentum)
        generation += 1  # without this the loop would never terminate
    end
end

#####################
# PSO invocation

function make_NN_constructor(min=10, max=100)
    h1f, h2f, h3f = min .+ (max - min) .* rand(3)
    function NN_constructor()
        # the closure captures the layer widths as Float64s so that differentials
        # (and hence PSO velocities) make sense; round to integer sizes only here
        h1i, h2i, h3i = clamp.(ceil.(Int, (h1f, h2f, h3f)), min, max)
        return Chain(
            Dense(28 * 28, h1i, relu),
            Dense(h1i, h2i, relu),
            Dense(h2i, h3i, relu),
            Dense(h3i, 10),
            softmax,  # softmax normalises the whole output, so it is its own layer
        )
    end
end

pso_optimise(make_NN_constructor)

Anyway, that would be absolutely rad. Doing a heuristic global optimization over closures.
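
For the pdata.position += pdata.velocity line to even run, + between a primal closure and its Composite needs a closure-aware construction path. A minimal sketch of the missing piece, assuming Composite's backing accessor, and written as type piracy here purely for illustration (in reality it would sit next to the existing method in differential_arithmetic.jl):

using ChainRulesCore
using ChainRulesCore: Composite, backing

# Build an instance of F via :new, sidestepping the missing constructor.
# Note :new does not convert, so the accumulated fields must already
# have the exact field types.
@generated function _new_instance(::Type{F}, fields::Tuple) where {F}
    return Expr(:new, F, (:(fields[$i]) for i in 1:fieldcount(F))...)
end

function Base.:+(f::F, c::Composite{F}) where {F<:Function}
    nt = backing(c)  # the NamedTuple inside the Composite
    new_fields = map(fieldnames(F)) do name
        old = getfield(f, name)
        haskey(nt, name) ? old + getfield(nt, name) : old
    end
    return _new_instance(F, new_fields)
end

With that, the snippet from the top of the issue returns a fresh closure: (foo + Composite{typeof(foo)}(; a_number=4.0))(2.0) == 18.0.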

oxinabox avatar Jul 18 '20 15:07 oxinabox