ChainRulesCore.jl

WIP: handle zero_tangent from cyclic data structures v2 via premapping

Open · oxinabox opened this issue 1 year ago · 2 comments

This is an alternative to https://github.com/JuliaDiff/ChainRulesCore.jl/pull/654. I think it is more promising, and it should minimize type instability outside of the locations where cycles actually occur.

Basically, here is the idea:

  1. Identify all objects that have multiple references to them.
  2. Go through and construct tangents for everything.
  3. Save the tangents for things with >1 reference into an IdDict keyed by the primal.
  4. When constructing tangent fields for things that have >1 reference to them, initially treat those fields like undef elements: leave the tangent value undef as well, or set it to a ZeroTangent placeholder.
  5. Go back through the tangent object and fill in the values for things with multiple references, using the tangents saved in step 3.
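The five steps above can be sketched as follows. This is a minimal Python illustration under stated assumptions, not the PR's Julia implementation and not ChainRulesCore's API: `Node` stands in for a mutable struct, `Tangent` for a structural tangent, `ZeroTangent` for the step-4 placeholder, and `count_references` / `zero_tangent_premapped` are hypothetical helper names. A dict keyed by Python's `id()` plays the role of Julia's `IdDict`, and every non-`Node` field is assumed to be a float.

```python
from collections import Counter


class Node:
    """Hypothetical mutable primal object with arbitrary fields (stand-in for a mutable struct)."""
    def __init__(self, **fields):
        self.__dict__.update(fields)


class ZeroTangent:
    """Hypothetical placeholder, standing in for ChainRulesCore's ZeroTangent()."""
    def __repr__(self):
        return "ZeroTangent()"


ZERO = ZeroTangent()


class Tangent:
    """Hypothetical structural tangent: maps field name -> tangent of that field."""
    def __init__(self, fields):
        self.fields = fields


def count_references(root):
    """Step 1: count incoming references to each Node, keyed by object identity."""
    counts = Counter({id(root): 1})  # the caller holds one reference to the root
    seen, stack = set(), [root]
    while stack:
        obj = stack.pop()
        if id(obj) in seen:
            continue  # each object's outgoing references are counted only once
        seen.add(id(obj))
        for value in vars(obj).values():
            if isinstance(value, Node):
                counts[id(value)] += 1
                stack.append(value)
    return counts


def zero_tangent_premapped(root):
    counts = count_references(root)
    shared = {}    # step 3: id(primal) -> tangent, for primals with >1 reference
    pending = []   # step 4: (fields dict, field name, primal) slots to backfill
    worklist = []  # shared primals whose tangents still need to be built

    def build(obj):
        """Step 2: build a tangent, using placeholders for multiply-referenced fields."""
        fields = {}
        for name, value in vars(obj).items():
            if not isinstance(value, Node):
                fields[name] = 0.0                    # assumed float leaf
            elif counts[id(value)] > 1:
                fields[name] = ZERO                   # step 4: placeholder for now
                pending.append((fields, name, value))
                if id(value) not in shared:
                    shared[id(value)] = None          # reserve so it is built only once
                    worklist.append(value)
            else:
                fields[name] = build(value)           # single reference: recurse directly
        return Tangent(fields)

    root_is_shared = counts[id(root)] > 1
    root_tangent = None
    if root_is_shared:
        shared[id(root)] = None
        worklist.append(root)
    else:
        root_tangent = build(root)
    while worklist:
        obj = worklist.pop()
        shared[id(obj)] = build(obj)                  # step 3: save by identity
    if root_is_shared:
        root_tangent = shared[id(root)]
    for fields, name, primal in pending:              # step 5: backfill placeholders
        fields[name] = shared[id(primal)]
    return root_tangent


# Usage: a two-node cycle a -> b -> a, and a leaf aliased by two fields.
a = Node(x=1.0)
b = Node(y=2.0, next=a)
a.next = b
t = zero_tangent_premapped(a)
print(t.fields["next"].fields["next"] is t)  # True: the tangent mirrors the cycle

leaf = Node(v=3.0)
pair = Node(left=leaf, right=leaf)
tp = zero_tangent_premapped(pair)
print(tp.fields["left"] is tp.fields["right"])  # True: one tangent for the shared primal
```

Placing placeholders only at multiply-referenced objects is enough to make the recursion terminate: any cycle reachable from the root must pass through some object with at least two incoming references (one from inside the cycle, one from outside it or from the caller). Every other field is built directly, which matches the stated aim of confining the irregular handling to the cyclic locations.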

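Because step 1 counts references rather than searching for cycles, this will also correctly handle aliases outside of loops, i.e. an object reachable through more than one path without any cycle.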
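To see why aliasing needs care even without a cycle, here is a hypothetical naive builder, again in Python. `Node` and `naive_zero_tangent` are illustrative names, not ChainRulesCore API. It recurses field by field without tracking object identity, so a single aliased primal gets two equal but distinct tangents. If tangents are meant to mirror the primal's object identity, that sharing is silently lost.

```python
class Node:
    """Hypothetical mutable primal object with arbitrary fields."""
    def __init__(self, **fields):
        self.__dict__.update(fields)


def naive_zero_tangent(obj):
    """Hypothetical naive builder that ignores object identity.

    Only usable on acyclic data: on a cycle it never terminates
    (Python raises RecursionError).
    """
    if isinstance(obj, Node):
        return {name: naive_zero_tangent(v) for name, v in vars(obj).items()}
    return 0.0  # assumed float leaf


leaf = Node(v=3.0)
pair = Node(left=leaf, right=leaf)  # one primal reachable through two fields

t = naive_zero_tangent(pair)
print(t["left"] == t["right"])  # True: equal values...
print(t["left"] is t["right"])  # False: ...but two separate tangent objects
```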

oxinabox, Jan 25 '24 07:01