ChainRulesCore.jl
Differential Addition for Closures
I guess these are a special case of things with constrained values, maybe?
Anyway, the following
function make_foo(a_number)
    return x -> a_number * x
end
foo = make_foo(5.0)
foo + Composite{typeof(foo)}(; a_number=4.0)
yields
ERROR: MethodError: no method matching var"#44#45"{Float64}(::Float64)
Stacktrace:
[1] macro expansion at /Users/willtebbutt/.julia/packages/ChainRulesCore/leg7n/src/differentials/composite.jl:0 [inlined]
[2] construct(::Type{var"#44#45"{Float64}}, ::NamedTuple{(:a_number,),Tuple{Float64}}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/leg7n/src/differentials/composite.jl:168
[3] +(::var"#44#45"{Float64}, ::Composite{var"#44#45"{Float64},NamedTuple{(:a_number,),Tuple{Float64}}}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/leg7n/src/differential_arithmetic.jl:96
[4] top-level scope at REPL[40]:1
[5] eval(::Module, ::Any) at ./boot.jl:331
[6] eval_user_input(::Any, ::REPL.REPLBackend) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.4/REPL/src/REPL.jl:86
[7] run_backend(::REPL.REPLBackend) at /Users/willtebbutt/.julia/packages/Revise/BqeJF/src/Revise.jl:1184
[8] top-level scope at REPL[2]:0
which I think stems from the fact that we're not able to create new instances of foo using the current approach.
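A quick way to see the underlying problem (my reading of the stacktrace above, not something stated elsewhere in the thread): a closure's type has no ordinary constructor, so the Type(field_values...) call that construct appears to make has nothing to dispatch to. For example:
foo = make_foo(5.0)
typeof(foo)(9.0)  # MethodError: no constructor method exists for the closure's type, same failure as above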
@MasonProtter kindly pointed out on Slack that one can construct new instances of closures as follows:
julia> macro new(T, vals...)
           esc(Expr(:new, T, vals...))
       end
@new (macro with 1 method)
julia> f(x) = y -> x*y + 1
f (generic function with 1 method)
julia> cl = f(2)
#5 (generic function with 1 method)
julia> cl(3)
7
julia> cl′ = @new typeof(cl) 3
#5 (generic function with 1 method)
julia> cl′(3)
10
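To connect the two, here is a rough sketch (mine, not ChainRulesCore's actual implementation; the helper name add_tangent is made up) of how the :new trick could implement adding a Composite onto a closure:
make_foo(a_number) = x -> a_number * x  # as at the top of this issue
function add_tangent(f::F, df) where {F<:Function}
    new_fields = map(fieldnames(F)) do name
        getfield(f, name) + getproperty(df, name)
    end
    # Bypass the missing constructor; a real implementation would presumably use
    # a @generated function or the @new macro above rather than runtime eval.
    return eval(Expr(:new, F, new_fields...))
end
foo = make_foo(5.0)
bar = add_tangent(foo, (a_number=4.0,))  # NamedTuple standing in for the Composite
bar(2.0)  # 18.0, i.e. (5.0 + 4.0) * 2.0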
Yep, this was left as pending-clear-need during the Composite PR.
If you have found a need then we can add it.
In that case I'll leave this open -- I was only trying to write a test for both rand_tangent and difference that involved a closure, rather than doing actual work I care about that happens to involve one. If someone has a spare half hour, this would be a great thing to get sorted, because it will almost certainly crop up at some point in the future.
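For concreteness, the kind of test I had in mind looks roughly like this (hypothetical: it assumes rand_tangent and difference currently live in FiniteDifferences.jl, as the comment in the sketch below suggests for difference, and that they already produce a Composite over a closure's captured fields; the additions on the last two lines are exactly what currently errors):
using Test
using FiniteDifferences: rand_tangent, difference
make_foo(a_number) = x -> a_number * x
foo = make_foo(5.0)
bar = make_foo(7.0)
@test foo + rand_tangent(foo) isa typeof(foo)
@test foo + difference(bar, foo) isa typeof(foo)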
I have a need for this: not for differentiable programming, but for difference-based optimization. My scheme is to use ChainRules' types to make e.g. Particle Swarm Optimizers usable without turning things into vectors.
If I wanted to do basic neural architecture search and I had this, I could do it really neatly.
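Before the full sketch, the core move in miniature (a hypothetical example on an ordinary struct rather than a closure, using the same difference function the sketch below imports):
using ChainRulesCore: difference  # mirroring the import in the sketch below
struct Params
    w::Float64
    b::Float64
end
current = Params(1.0, -2.0)
target = Params(3.0, 0.5)
# A structured "velocity": a differential over the fields of Params,
# obtained without ever flattening anything into a Vector.
step = 0.1 * difference(target, current)
# Differential addition nudges current toward target, field by field.
current = current + step
For a plain struct like Params the addition on the last line goes through ChainRulesCore's existing construct path; the point of this issue is making the same thing work when the position is a closure, as in the sketch below.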
Sketch based on https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl
using Flux
using Flux.Data: DataLoader
using MLDatasets: MNIST
using Statistics
using ChainRulesCore: difference # this needs to move out of FiniteDifferences.jl
##########################################
# Neural Net Code
const train_data, test_data = let
    batchsize = 100
    # Loading Dataset
    xtrain, ytrain = MNIST.traindata(Float32)
    xtest, ytest = MNIST.testdata(Float32)
    xtrain = Flux.flatten(xtrain)
    xtest = Flux.flatten(xtest)
    ytrain = onehotbatch(ytrain, 0:9)
    ytest = onehotbatch(ytest, 0:9)
    # Batching
    train_data = DataLoader(xtrain, ytrain, batchsize=batchsize, shuffle=true)
    test_data = DataLoader(xtest, ytest, batchsize=batchsize)
    train_data, test_data
end
function test_accuracy(model)
    return mean(test_data) do (x, y)
        sum(onecold(model(x)) .== onecold(y)) / size(x, 2)
    end
end
# the particle *is* a NN_constructor
function train_and_evaluate(NN_constructor)
    model = NN_constructor()
    opt = ADAM()
    num_epochs = 10
    loss(x, y) = crossentropy(model(x), y)
    for ii in 1:num_epochs
        Flux.train!(loss, params(model), train_data, opt)
    end
    return test_accuracy(model)
end
#######################################
# PSO
mutable struct ParticleMetadata{P, D}
    position::P  # e.g. a NN_constructor
    velocity::D  # e.g. a differential for a NN_constructor
    pbest::P     # e.g. a NN_constructor
    pbest_score::Float64
end
function pso_evolve!(
    pdata::ParticleMetadata, gbest;
    g_rate=0.5, p_rate=0.3, momentum=0.9,
)
    pdata.velocity =
        momentum * pdata.velocity +
        g_rate * difference(gbest, pdata.position) +
        p_rate * difference(pdata.pbest, pdata.position)
    pdata.position += pdata.velocity
    score = train_and_evaluate(pdata.position)
    if score > pdata.pbest_score
        pdata.pbest = pdata.position
        pdata.pbest_score = score
    end
    return score
end
function pso_optimise(particle_creator;
    g_rate=0.5, p_rate=0.3, momentum=0.9,
    num_particles=100, num_generations=10,
)
    tagged_particles = map(1:num_particles) do _
        pos = particle_creator()  # random initial position
        # random initial velocity
        vel = difference(particle_creator(), particle_creator())
        score = train_and_evaluate(pos)
        return ParticleMetadata(pos, vel, pos, score)
    end
    generation = 1
    g_score = -Inf
    g_best = nothing
    while true
        new_g_score, idx = findmax([p.pbest_score for p in tagged_particles])
        if new_g_score >= g_score
            g_score = new_g_score
            g_best = tagged_particles[idx].pbest
        end
        # maybe put a callback to do some plotting etc here
        generation > num_generations && return g_best
        generation += 1
        pso_evolve!.(tagged_particles, Ref(g_best); g_rate=g_rate, p_rate=p_rate, momentum=momentum)
    end
end
#####################
# PSO invocation
function make_NN_constructor(min=10, max=100)
    h1f, h2f, h3f = min .+ (max - min) .* rand(3)
    function NN_constructor()
        h1i, h2i, h3i = clamp.(ceil.(Int, (h1f, h2f, h3f)), min, max)
        return Chain(
            Dense(28 * 28, h1i, relu),
            Dense(h1i, h2i, relu),
            Dense(h2i, h3i, relu),
            Dense(h3i, 10),
            softmax,
        )
    end
end
pso_optimise(make_NN_constructor)
Anyway, that would be absolutely rad: doing a heuristic global optimization over closures.