ForwardDiff.jl

Gradients with respect to struct fields?

Open · NAThompson opened this issue 1 year ago · 2 comments

In Zygote.jl, we can take the gradient with respect to all fields of a struct `foo` passed through a function `bar` via:

```julia
g = Zygote.gradient(f -> bar(f), foo)
```

Can this be done in ForwardDiff as well?
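`ForwardDiff.gradient` differentiates with respect to an array input, so one possible workaround (a sketch under stated assumptions, not a built-in ForwardDiff feature) is to pack the struct's fields into a vector, differentiate a wrapper that rebuilds the struct from that vector, and relabel the result by field name. The helpers `to_vec`, `from_vec`, and `field_gradient` are hypothetical names, and the type parameter `T<:Real` is an assumption that lets the same struct hold either `Float64` values or ForwardDiff's dual numbers. For `bar(f) = f.x - f.c*f.t`, the gradient by hand is $(1, -c, -t)$.

```julia
using ForwardDiff

# Assumption: a parametric struct, so its fields can hold Float64 or ForwardDiff.Dual values.
struct Foo{T<:Real}
    x::T
    t::T
    c::T
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helpers: flatten a Foo into a Vector, and rebuild a Foo from a Vector.
to_vec(f::Foo) = [getfield(f, n) for n in fieldnames(typeof(f))]
from_vec(v::AbstractVector) = Foo(v...)

# Gradient with respect to every field, relabeled as a NamedTuple keyed by field name.
function field_gradient(func, f::Foo)
    g = ForwardDiff.gradient(v -> func(from_vec(v)), to_vec(f))
    return NamedTuple{fieldnames(typeof(f))}(Tuple(g))
end

foo = Foo(2.0, 3.0, 3e8)
g = field_gradient(bar, foo)
println(g)
```

Because the parametric constructor needs all three fields to share one type `T`, the example uses `Float64` literals; `Foo(2, 3, 3e8)` would not match. For these values the result should be `x = 1.0`, `t = -3.0e8`, `c = -3.0`. Since this defines `Foo` differently from the reproducer below, run it in a separate Julia session.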

Reproducer:

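```julia
using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

# Zygote returns a 1-tuple whose entry is a NamedTuple of per-field gradients.
g = Zygote.gradient(f -> bar(f), foo)
println(g)

# ForwardDiff.gradient expects an AbstractArray input, so passing a struct throws an error here.
g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)
```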
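A second option keeps the reproducer's struct unchanged: because its fields are typed `::Number`, they accept ForwardDiff's dual numbers, so each partial derivative can be taken with `ForwardDiff.derivative` by varying one field while holding the others fixed. This is a minimal sketch, not an official ForwardDiff recipe; it costs one forward pass per field, which is cheap for three fields but grows with the field count. It uses `Float64` field values.

```julia
using ForwardDiff

# Same struct as the reproducer: abstractly typed fields also accept ForwardDiff.Dual values.
struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2.0, 3.0, 3e8)

# One forward-mode pass per field: perturb a single field, hold the other two fixed.
dx = ForwardDiff.derivative(x -> bar(Foo(x, foo.t, foo.c)), foo.x)
dt = ForwardDiff.derivative(t -> bar(Foo(foo.x, t, foo.c)), foo.t)
dc = ForwardDiff.derivative(c -> bar(Foo(foo.x, foo.t, c)), foo.c)

println((x = dx, t = dt, c = dc))
```

These match the hand-derived partials $\partial_x = 1$, $\partial_t = -c = -3\times10^{8}$, and $\partial_c = -t = -3$. The abstract `::Number` fields make this work but are slower than a parametric struct.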

NAThompson, Aug 18 '24 15:08