Gradients with respect to struct fields?
In Zygote.jl, we can take the gradient with respect to all fields of a struct `foo` passed through a function `bar` via

```julia
g = Zygote.gradient(f -> bar(f), foo)
```
Can this be done in ForwardDiff as well?
Reproducer:
```julia
using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

g = Zygote.gradient(f -> bar(f), foo)
println(g)

g = ForwardDiff.gradient(f -> bar(f), foo)  # errors: ForwardDiff.gradient needs an array input, not a Foo
println(g)
```
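Not straightforward. ForwardDiff differentiates with respect to numbers (`ForwardDiff.derivative`) and arrays (`ForwardDiff.gradient` expects an `AbstractArray`), not arbitrary structs. You might be able to hack something together with generated functions.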
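A simpler workaround, swapped in here instead of generated functions, is to lean on the array interface directly: pack the fields into a vector, and rebuild the struct inside the closure so ForwardDiff only ever sees a function of an `AbstractVector`. This is a minimal sketch assuming every field is a real number and that `Foo`'s default constructor takes the fields in declaration order; `v` and `w` are just illustrative names. It only works because `Foo`'s fields are typed `::Number`, so they can hold the `Dual` values ForwardDiff passes in; a struct with concrete `Float64` fields would reject them.

```julia
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# Pack the fields into a plain vector, which ForwardDiff.gradient accepts.
v = [foo.x, foo.t, foo.c]

# Rebuild a Foo from the (Dual-valued) vector inside the closure, so
# ForwardDiff differentiates an ordinary function of an AbstractVector.
g = ForwardDiff.gradient(w -> bar(Foo(w...)), v)

# g[i] is the derivative with respect to the i-th packed field (x, t, c).
println(g)
```

For the `foo` above this should give `[1.0, -3.0e8, -3.0]`, the same values Zygote returns, but as a positional vector rather than a named tuple.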
If you want to do this yourself and your struct contains only real numbers, it is fairly simple:
```julia
julia> using ForwardDiff: Dual, partials
julia> make_dual(z::Foo) = Foo(Dual(z.x,1,0,0), Dual(z.t,0,1,0), Dual(z.c,0,0,1));
julia> get_Foo(dy::Dual) = (; x=partials(dy,1), t=partials(dy,2), c=partials(dy,3));
julia> get_Foo(bar(make_dual(foo)))
(x = 1.0, t = -3.0e8, c = -3.0)
julia> Zygote.gradient(bar, foo)[1]
(x = 1.0, t = -3.0e8, c = -3.0)
```
With a bit more work you could automate this into a generic `struct_gradient(f, x)` that works for any struct of numbers, and even extend it to structs of structs.
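A rough version of that automation for flat structs could look like the following minimal sketch. `struct_gradient` is the hypothetical name from the text, not a ForwardDiff function. It generalizes `make_dual`/`get_Foo` by seeding field `i` with the `i`-th unit partial and reading the partials back into a named tuple keyed by field name, which mirrors Zygote's output. It assumes a scalar-valued `f`, a struct whose fields are all real numbers, a default constructor that takes the fields in declaration order, and field types that accept `Dual` values (true for `Foo`'s `::Number` fields). It does not handle nested structs or arrays.

```julia
using ForwardDiff: Dual, partials

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helper (not part of ForwardDiff): gradient of a scalar-valued
# `f` with respect to every field of a flat struct `s` of real numbers.
function struct_gradient(f, s::T) where {T}
    fnames = fieldnames(T)
    N = length(fnames)
    # Seed field i with the i-th unit partial, as make_dual does by hand.
    duals = ntuple(N) do i
        Dual(getfield(s, i), ntuple(j -> j == i ? 1 : 0, N)...)
    end
    # Rebuild the struct with Dual fields; assumes the default constructor
    # takes fields in declaration order and accepts Dual values.
    y = f(T(duals...))
    # Read back one partial per field, keyed by field name.
    return NamedTuple{fnames}(ntuple(i -> partials(y, i), N))
end

g = struct_gradient(bar, Foo(2, 3, 3e8))
println(g)
```

For the `foo` from the question this should return `(x = 1.0, t = -3.0e8, c = -3.0)`, matching both the hand-written version and Zygote. Using one partial per field means the number of partials equals the number of fields, which is fine for small structs.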
Supporting structs that contain arrays is much trickier, mainly because of ForwardDiff's chunk mode.