DiffEqFlux.jl
(WIP) Implement a ForwardDiff version of FFJORD
This implements a forward-mode version of FFJORD via ForwardDiff. However, Dual tag ordering issues are showing up, so it's failing.
using DiffEqFlux, DifferentialEquations, GalacticOptim, Distributions
nn = Chain(
    Dense(1, 3, tanh),
    Dense(3, 1, tanh),
) |> f32
tspan = (0.0f0, 10.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())
data_dist = Normal(6.0f0, 0.7f0)
train_data = rand(data_dist, 1, 100)
function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end
adtype = GalacticOptim.AutoZygote()
res1 = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(0.1), adtype; maxiters=100)
The same applies with GalacticOptim.AutoForwardDiff() as the adtype.
I'm trying to fix this, but the failing test looks really strange, even after I've fixed (I think) the problem with the Dual tag. The e variable is always just true, even when monte_carlo is true. Zygote somehow manages to work with it, but auto_jacvec does not. Does backpropagating true have some special meaning in Zygote?
Yeah I'm not sure, and that's why I dropped this for a bit. I'm not convinced it's not a Zygote bug.
OK. In any case, what I did to auto_jacvec was replace broadcasting with map. Zygote's broadcast differentiation is very likely not perfect, so avoiding unnecessary broadcasting in differentiated code seems like a good idea.
function auto_jacvec(f, x, v)
    fval = f(map((xi, vi) -> Dual{typeof(ForwardDiff.Tag(f,eltype(x)))}(xi, vi), x, v))
    map(u -> partials(u)[1], fval)
end
I've also changed the tag here to more closely resemble what ForwardDiff.jl does; I'm not sure why auto_jacvec has its own separate tag there.
Yeah, that's a better tag. Would be worth upstreaming.
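For anyone who wants to sanity-check the map-based auto_jacvec from this thread outside DiffEqFlux, here is a minimal standalone sketch. It assumes only ForwardDiff.jl is installed; the test function g and the vectors x and v are hypothetical and chosen just for illustration. It compares the Jacobian-vector product against the full Jacobian from ForwardDiff.jacobian, and it does not exercise the Zygote path or the FFJORD model.

```julia
using ForwardDiff
using ForwardDiff: Dual, partials, Tag

# Map-based Jacobian-vector product, as proposed in this thread.
# Each input x[i] is seeded with the direction component v[i] as its
# single partial, using a tag built from f and eltype(x).
function auto_jacvec(f, x, v)
    T = typeof(Tag(f, eltype(x)))
    fval = f(map((xi, vi) -> Dual{T}(xi, vi), x, v))
    map(u -> partials(u)[1], fval)
end

# Hypothetical test function R^2 -> R^2 (not part of the PR).
g(x) = [x[1]^2 + x[2], sin(x[1]) * x[2]]

x = [1.0, 2.0]
v = [0.5, -1.0]

jv = auto_jacvec(g, x, v)        # J(x) * v without forming J
J  = ForwardDiff.jacobian(g, x)  # full Jacobian for comparison

@assert jv ≈ J * v
```

Building the tag from f and eltype(x) mirrors how ForwardDiff constructs tags for its own configs, which is the change discussed above.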