DiffEqFlux.jl

(WIP) Implement a ForwardDiff version of FFJORD

Status: Open. ChrisRackauckas opened this issue 4 years ago (4 comments).

This implements a forward-mode version of FFJORD using ForwardDiff dual numbers. However, it currently fails because of Dual tag ordering issues.

```julia
using DiffEqFlux, DifferentialEquations, GalacticOptim, Distributions, Statistics

# Small Float32 network defining the FFJORD vector field
nn = Chain(
    Dense(1, 3, tanh),
    Dense(3, 1, tanh),
) |> f32
tspan = (0.0f0, 10.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())

# 100 one-dimensional samples from the target distribution
data_dist = Normal(6.0f0, 0.7f0)
train_data = rand(data_dist, 1, 100)

# Negative mean log-likelihood of the training data (λ₁, λ₂ are unused here)
function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end

adtype = GalacticOptim.AutoZygote()
res1 = DiffEqFlux.sciml_train(loss, ffjord_mdl.p, ADAM(0.1), adtype; maxiters=100)
```

Also test it with `GalacticOptim.AutoForwardDiff()` as the `adtype`.

ChrisRackauckas, Aug 28 '21 21:08

I'm trying to fix this, but the failing test looks really strange, even after I've (I think) fixed the problem with the Dual tag. The `e` variable is always just `true`, even when `monte_carlo` is `true`. Zygote somehow manages to work with that, but `auto_jacvec` does not. Does backpropagating `true` have some special meaning in Zygote?

mateuszbaran, Sep 15 '21 11:09
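For context on why `e` collapsing to a plain `true` looks suspicious: in FFJORD the change in log-density depends on the trace of the Jacobian of the vector field, and with Monte Carlo estimation that trace is replaced by Hutchinson's estimator eᵀJe, where `e` is a random probe vector. Below is a minimal forward-mode sketch of both trace computations, assuming only ForwardDiff.jl and the standard libraries. `jvp`, `exact_trace`, and `hutchinson_trace` are hypothetical helpers, not DiffEqFlux's internals, and the Rademacher probe is one common choice (Gaussian probes also work).

```julia
using ForwardDiff, LinearAlgebra, Random

# Jacobian-vector product J(x) * v via one forward-mode derivative:
# d/dt f(x + t*v) evaluated at t = 0 equals J(x) * v.
jvp(f, x, v) = ForwardDiff.derivative(t -> f(x .+ t .* v), zero(eltype(x)))

# Exact trace of the Jacobian (the non-Monte-Carlo case).
exact_trace(f, x) = tr(ForwardDiff.jacobian(f, x))

# Hutchinson estimator: E[eᵀ J e] = tr(J) whenever E[e eᵀ] = I.
function hutchinson_trace(f, x; nsamples = 1, rng = Random.default_rng())
    T = eltype(x)
    est = zero(T)
    for _ in 1:nsamples
        # Rademacher probe: each entry is -1 or +1 with equal probability.
        e = rand(rng, [-one(T), one(T)], length(x))
        est += dot(e, jvp(f, x, e))
    end
    return est / nsamples
end
```

For a diagonal Jacobian, such as that of `x -> sin.(x)`, every Rademacher sample returns the exact trace, which makes a convenient unit test; for a general Jacobian the estimate is only correct in expectation.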

Yeah, I'm not sure, and that's why I dropped this for a bit. I suspect it may be a Zygote bug.

ChrisRackauckas, Sep 16 '21 15:09

OK. In any case, what I changed in `auto_jacvec` was to replace broadcasting with `map`. Zygote's differentiation of broadcasting is very likely not perfect, so avoiding unnecessary broadcasting in differentiated code seems like a good idea:

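```julia
using ForwardDiff
using ForwardDiff: Dual, partials

# J(x) * v: seed each x[i] with a single partial equal to v[i], then read it back
function auto_jacvec(f, x, v)
    fval = f(map((xi, vi) -> Dual{typeof(ForwardDiff.Tag(f, eltype(x)))}(xi, vi), x, v))
    map(u -> partials(u)[1], fval)
end
```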

I've also changed the tag here to more closely match what ForwardDiff.jl does; I'm not sure why `auto_jacvec` had its own separate tag there.

mateuszbaran, Sep 16 '21 20:09
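As a quick sanity check of the map-based `auto_jacvec`, the sketch below repeats its definition so the block runs on its own and compares the result with an explicit Jacobian product from `ForwardDiff.jacobian`. The tag is built as `ForwardDiff.Tag(f, eltype(x))`, which is also the default tag ForwardDiff's own config constructors use; the test function `f` and the points `x`, `v` are arbitrary assumptions for this check.

```julia
using ForwardDiff
using ForwardDiff: Dual, partials

# Same map-based JVP as in the comment above, repeated for self-containment.
function auto_jacvec(f, x, v)
    T = typeof(ForwardDiff.Tag(f, eltype(x)))       # ForwardDiff-style tag
    fval = f(map((xi, vi) -> Dual{T}(xi, vi), x, v)) # one partial per entry
    map(u -> partials(u)[1], fval)                   # extract J(x) * v
end

# Arbitrary nonlinear test function (assumption for this check).
f(x) = [x[1]^2 + x[2], sin(x[1]) * x[2]]

x = [0.3, 2.0]
v = [1.0, -0.5]

jv = auto_jacvec(f, x, v)
reference = ForwardDiff.jacobian(f, x) * v
println(jv ≈ reference)  # prints true
```

Because only one partial is seeded per input, this costs a single forward pass of `f`, whereas forming the full Jacobian first needs one partial per input dimension.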

Yeah, that's a better tag. Would be worth upstreaming.

ChrisRackauckas, Sep 16 '21 20:09