ChainRulesTestUtils.jl
structural rand_tangent
rand_tangent prefers to generate natural tangents where possible. Sometimes, though, it's preferable to generate structural tangents. What's the feeling about the best way to accommodate this need?
Probably introduce a rand_structural_tangent? We already have code for that, so we can trivially reuse it.
One complicating factor is how deep you go. E.g. what about a
struct Foo
    x::Diagonal
    y::Float64
end
Does x get the natural or the structural tangent?
Do we introduce a max-depth parameter? What does it default to?
What if the depth should differ between fields?
At what point should we just require people to construct their own tangent?
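A minimal sketch of the max-depth idea, to make the Foo question concrete. Everything here is hypothetical: rand_structural_tangent, natural_rand_tangent and the max_depth keyword are illustrative names, not an existing ChainRulesTestUtils API. A plain NamedTuple of field tangents stands in for ChainRulesCore's Tangent so the example runs with only the standard library. While depth budget remains, fields are recursed into; once it is spent, the natural tangent is used.

```julia
using LinearAlgebra, Random

# Hypothetical stand-ins for the natural tangents rand_tangent would produce.
natural_rand_tangent(rng::AbstractRNG, ::Float64) = randn(rng)
natural_rand_tangent(rng::AbstractRNG, x::Array{Float64}) = randn(rng, size(x))
natural_rand_tangent(rng::AbstractRNG, x::Diagonal{Float64}) = Diagonal(randn(rng, size(x, 1)))

# Hypothetical: recurse into fields until max_depth is spent, then switch to the
# natural tangent. Float64s and dense Float64 arrays are always leaves. A type
# with no natural_rand_tangent method (e.g. Foo) must not run out of budget.
# A NamedTuple of field tangents stands in for ChainRulesCore.Tangent.
function rand_structural_tangent(rng::AbstractRNG, x; max_depth::Int=1)
    if max_depth <= 0 || x isa Union{Float64, Array{Float64}}
        return natural_rand_tangent(rng, x)
    end
    names = fieldnames(typeof(x))
    vals = map(n -> rand_structural_tangent(rng, getfield(x, n); max_depth=max_depth - 1), names)
    return NamedTuple{names}(vals)
end

struct Foo
    x::Diagonal
    y::Float64
end

foo = Foo(Diagonal([1.0, 2.0]), 3.0)
shallow = rand_structural_tangent(MersenneTwister(0), foo; max_depth=1)  # x stays a Diagonal
deep = rand_structural_tangent(MersenneTwister(0), foo; max_depth=2)     # x becomes (diag = ...,)
```

With max_depth=1 the x field of Foo gets a natural Diagonal tangent; with max_depth=2 it becomes structural, with its own diag field. Giving different fields different depths would need something richer than a single integer.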
API-wise, passing in a primal x is conceptually the same as x ⊢ Auto() (though I think that doesn't in fact work right now: Auto is only set up for output_tangent). So we could introduce a new x ⊢ AutoStruct() or something similar.
Do we introduce a max-depth parameter? What does it default to?
I wasn't imagining doing this. Usually, if I want a structural tangent, I want everything to be a structural tangent 🤷. The issue is just determining which types are considered primitive (something I don't believe we've really pinned down?), meaning that they never use the structural tangent and so have to return whatever their prescribed natural tangent type is.
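The "everything structural, except primitives" approach can be sketched with a predicate in place of a depth limit. Which types count as primitive is exactly the open question, so is_primitive below (Float64 and dense Float64 arrays only) is an assumption, as are the names is_primitive, natural_tangent and rand_all_structural. A NamedTuple again stands in for ChainRulesCore's Tangent.

```julia
using LinearAlgebra, Random

# Hypothetical: types treated as primitive always get their natural tangent.
# Which types belong here is the open question; this choice is an assumption.
is_primitive(::Any) = false
is_primitive(::Float64) = true
is_primitive(::Array{Float64}) = true

# Hypothetical natural tangents for the primitive types above.
natural_tangent(rng::AbstractRNG, ::Float64) = randn(rng)
natural_tangent(rng::AbstractRNG, x::Array{Float64}) = randn(rng, size(x))

# Recurse through fields until a primitive is reached; a NamedTuple of field
# tangents stands in for ChainRulesCore.Tangent. No depth limit is needed.
function rand_all_structural(rng::AbstractRNG, x)
    is_primitive(x) && return natural_tangent(rng, x)
    names = fieldnames(typeof(x))
    isempty(names) && error("no primitive rule and no fields for $(typeof(x))")
    vals = map(n -> rand_all_structural(rng, getfield(x, n)), names)
    return NamedTuple{names}(vals)
end

struct Foo
    x::Diagonal
    y::Float64
end

t = rand_all_structural(MersenneTwister(0), Foo(Diagonal([1.0, 2.0]), 3.0))
# t.x is a NamedTuple with a diag field rather than a Diagonal
```

Here Diagonal is not primitive, so it is decomposed down to its diag vector. Someone who wanted Diagonal kept natural could add is_primitive and natural_tangent methods for it.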
At what point should we just require people to construct their own tangent?
If they want a mix of the two?