ForwardDiff.jl

Gradients with respect to struct fields?

Open NAThompson opened this issue 1 year ago • 2 comments

In Zygote.jl, we can take the gradient with respect to all fields of a struct foo passed through a function bar via

g = Zygote.gradient(f -> bar(f), foo)

Can this be done in ForwardDiff as well?

Reproducer:

using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

g = Zygote.gradient(f -> bar(f), foo)

println(g)


# Errors: ForwardDiff.gradient only accepts array inputs, not a struct like Foo
g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)
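For reference, since bar(f) = x - c·t, the gradient both calls are after can be worked out by hand at foo = Foo(2, 3, 3e8):

$$
\frac{\partial\,\mathrm{bar}}{\partial x} = 1,\qquad
\frac{\partial\,\mathrm{bar}}{\partial t} = -c = -3\times 10^{8},\qquad
\frac{\partial\,\mathrm{bar}}{\partial c} = -t = -3
$$

This is the (x = 1.0, t = -3.0e8, c = -3.0) that Zygote returns.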

NAThompson · Aug 18 '24 15:08

Not straightforward. ForwardDiff differentiates with respect to numbers and arrays (AbstractArray), not arbitrary structs. You might be able to hack something together with generated functions.
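Since ForwardDiff.gradient does accept a plain Vector, one simple workaround (not proposed in the thread; a sketch that relies only on that documented behaviour) is to pack the fields into a vector and rebuild the struct inside a closure:

```julia
using ForwardDiff

# Same struct as in the reproducer; the abstract field types let Foo hold Duals.
struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# Pack the fields into a Vector (promoted to Vector{Float64}),
# and rebuild a Foo from that vector inside the closure.
v = [foo.x, foo.t, foo.c]
g = ForwardDiff.gradient(v -> bar(Foo(v...)), v)

# Re-attach the field names to the flat gradient.
grad = NamedTuple{fieldnames(Foo)}(Tuple(g))
println(grad)
```

This copies the fields and promotes them all to one element type, but needs no generated functions or manual Dual handling.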

KristofferC · Aug 19 '24 09:08

If you want to do this yourself, and only have a struct of real numbers, then it will be fairly simple:

julia> using ForwardDiff: Dual, partials

julia> make_dual(z::Foo) = Foo(Dual(z.x,1,0,0), Dual(z.t,0,1,0), Dual(z.c,0,0,1));  # seed each field with a one-hot partial

julia> get_Foo(dy::Dual) = (; x=partials(dy,1), t=partials(dy,2), c=partials(dy,3));  # read one partial back per field

julia> get_Foo(bar(make_dual(foo)))
(x = 1.0, t = -3.0e8, c = -3.0)

julia> Zygote.gradient(bar, foo)[1]
(x = 1.0, t = -3.0e8, c = -3.0)

With a bit more work you could automate this as a function struct_gradient(f, x) that works for many structs of numbers, and even for structs of structs.
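A minimal sketch of that automation for flat structs of real numbers. struct_gradient is hypothetical (named after the comment above, not a ForwardDiff function), and it assumes T(fields...) can rebuild the struct from Duals, i.e. the fields are typed loosely (Number/Real, as in Foo) or parametrically:

```julia
using ForwardDiff: Dual, partials

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# Hypothetical helper: gradient of a scalar function f with respect to
# every field of s, returned as a NamedTuple keyed by field name.
function struct_gradient(f, s::T) where {T}
    fnames = fieldnames(T)
    N = length(fnames)
    # Field i gets a Dual whose i-th partial is 1 and whose others are 0.
    duals = ntuple(N) do i
        Dual(getfield(s, i), ntuple(j -> j == i ? 1 : 0, N)...)
    end
    y = f(T(duals...))
    # Read off one partial per field and label it with the field name.
    return NamedTuple{fnames}(ntuple(i -> partials(y, i), N))
end

println(struct_gradient(bar, foo))
```

Like make_dual above, this uses the default Nothing tag, so it is not safe against perturbation confusion if f itself calls ForwardDiff; a more careful version would use a proper ForwardDiff.Tag.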

Allowing structs that contain arrays will be much trickier, mainly because of ForwardDiff's chunk mode.
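One way to avoid handling chunk mode by hand is to flatten everything into a single vector and let ForwardDiff.gradient do the seeding and chunking itself. A sketch assuming a hypothetical Model struct with one scalar and one vector field, and hypothetical flatten/unflatten helpers:

```julia
using ForwardDiff

# Hypothetical struct mixing a scalar field and an array field.
struct Model
    a::Real
    w::Vector
end

loss(m::Model) = m.a * sum(abs2, m.w)

# Hypothetical helpers: pack all fields into one Vector and rebuild a Model from it.
flatten(m::Model) = vcat(m.a, m.w)
unflatten(v, n) = Model(v[1], v[2:1+n])

m = Model(2.0, [1.0, 2.0, 3.0])
n = length(m.w)
g = ForwardDiff.gradient(v -> loss(unflatten(v, n)), flatten(m))

# Split the flat gradient back into per-field pieces.
grad = (a = g[1], w = g[2:end])
println(grad)
```

The bookkeeping in flatten/unflatten is exactly what a generic version would have to derive from the struct layout, which is where the extra work lies.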

mcabbott · Aug 24 '24 19:08