Enzyme.jl
Add recursive map generalizing the make_zero mechanism
This is to explore functionality for realizing https://github.com/JuliaMath/QuadGK.jl/pull/120. The current draft cuts time and allocations in half for the MWE in that PR compared to the make_zero hack from the comments. Not sure if modifying the existing recursive_* functions like this is appropriate or whether it would be better to implement a separate deep_recursive_accumulate.
This probably breaks some existing uses of recursive_accumulate, like the Holomorphic derivative code, because recursive_accumulate now traverses most/all of the structure on its own and will double-accumulate when combined with the iteration over the seen IdDicts. Curious to see the total impact on the test suite.
This doesn't yet have any concept of seen and will thus double-accumulate if the structure has internal aliasing. That obviously needs to be fixed. Perhaps we can factor out and share the recursion code from make_zero.
A bit of a tangent, but perhaps a final version of this PR should include migrating ClosureVector to Enzyme from the QuadGK ext as suggested in https://github.com/JuliaMath/QuadGK.jl/pull/110#issuecomment-2323012377. Looks like that's the most relevant application of fully recursive accumulation at the moment.
Let me also throw out another suggestion: what if we implement a recursive generalization of broadcasting with an arbitrary number of arguments, i.e., recursive_broadcast!(f, a, b, c, ...) as a recursive generalization of a .= f.(b, c, ...), free of intermediate allocations whenever possible (and similarly an out-of-place recursive_broadcast(f, a, b, c...) generalizing f.(a, b, c...) that only materializes/allocates once if possible). That would enable more optimized custom rules with Duplicated args, such as having the QuadGK rule call the in-place version quadgk!(f!, result, segs...). Not sure if it would be hard to correctly handle aliasing without being overly defensive, or if that could mostly be taken care of by proper reuse of the existing broadcasting functionality.
Alright, I could take some feedback/discussion on this now.
- This implements a generic `recursive_map` for mapping a function over the differentiable values in arbitrary tuples of identical data structures. There's an in-place equivalent `recursive_map!` for mutable values, built on top of a bangbang-style `recursive_map!!` that works on arbitrary types and reuses all the mutable memory (similar to the old `make_zero_immutable!` but without code duplication).
- The code is diffed against the old `make_zero(!)` code on github, but this is a complete rewrite and the diff will probably not be helpful for reviewing. Let me know if you want me to rename files or something to get rid of github's diff view.
- The implementation is leaner and simpler than the old one, and even though `recursive_map{!!}` aren't public I wrote extensive docstrings to clarify the spec for myself and others, so I don't think the code should be too difficult to review from scratch.
- I added fast paths such that new structs are allocated using `splatnew` with a tuple instead of `ccall` with a vector in the common case where there are no undefined fields. This gives a substantial speedup in many cases, which is good since `recursive_map` will be called in hot loops in custom quadrature rules and the like.
- I have refactored `make_zero` and `make_zero!` to be minimal wrappers around `recursive_map{!}`, without changing their public API.
- To stay safe while doing such a big refactoring, I wrote extensive tests with ~full coverage of both the old and new implementations of `make_zero(!)` (a small number of edge-case branches aren't covered only because I can't find a way to reach them from any public entry point).
- These tests uncovered quite a few bugs in the existing make_zero(!) implementations. See the following commit on a separate branch in my fork for the necessary fixes to get the old code to pass the new tests: https://github.com/danielwe/Enzyme.jl/commit/7a6ca9ff332b413c80a2cc540a92c79eb026f2eb. (Note, one of the tests still errors due to #1935.)
- I have not refactored `recursive_add` and `recursive_accumulate`, since these currently have different semantics where they don't recurse into mutable values. I'm happy to go ahead and do the refactoring if you're OK with changing their semantics.
TLDR: Should I rewrite `recursive_add` and `recursive_accumulate` to be based on `recursive_map{!}` and have full recursion semantics? Anything else?
@gdalle I promised to tag you when this was ready for review, but note that this PR only deals with the low-level, non-public guts of the implementation. I'll do the vector space wrapper in a separate PR as soon as this is merged (hopefully that won't be long, I really need that QuadGK rule for my research :triangular_ruler:)
Update for anyone who's following: I've implemented the VectorSpace wrapper, which prompted me to adjust the recursive_map implementation a bit, all for the better. It's looking good and will make writing custom higher-order rules as well as the DI wrappers a lot nicer for arbitrary types. However, it dawned on me that you probably want make_zero to be easily extensible by just adding methods, like what's already done in the StaticArrays extension. That will require a bit of redesign, nothing too hard, but I've got weekend plans so might not get to it until next week.
awesome, sorry I haven't had a chance to review yet [just a bunch of shenanigans atm], I'll try to take a closer look next week and ping me if not
No worries! I restored the draft label when I realized there was a bit more to do and will remove it again once I think this is ready for review. No need to look at it until then, the current state here on github doesn't reflect what I'm working with locally anyway.
At long last, I think this one's ready for you to take a look. Hit me with any questions and concerns, from major design issues to bikeshedding over names.
I put both the implementation and tests in their own modules because they define a lot of helpers and I didn't want to pollute other modules' namespaces.
Codecov Report
Attention: Patch coverage is 87.69716% with 39 lines in your changes missing coverage. Please review.
Project coverage is 75.35%. Comparing base (037dfed) to head (8bea510). Report is 355 commits behind head on main.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1852      +/-   ##
==========================================
+ Coverage   67.50%   75.35%    +7.84%
==========================================
  Files          31       56       +25
  Lines       12668    16635     +3967
==========================================
+ Hits         8552    12535     +3983
+ Misses       4116     4100       -16
```
I'm still knee deep in 1.11 land and don't have cycles to review this immediately. @vchuravy can you take a look?
1.11 efforts deeply appreciated! Don't rush this. I'll keep using 1.10 and a local fork for my own needs, and occasionally push small changes here as my tinkering surfaces new concerns/opportunities.
@danielwe is this the one now to review or is there a different related PR I should review first? (and also would you mind rebasing)
This is the one wrt. recursive maps and all that, but I've been continually refining stuff locally, so I need to both rebase and push the latest changes, hang on! Will remove draft status when ready for review.
Finally wrapped up and rebased this! I'll come back later and write a little blurb, but the code should be ready for review as-is.
Nice!
```
WARNING: Method definition inactive_type(Type{var"#s2630"} where var"#s2630"<:(RecursiveMapTests.CustomVector{T} where T)) in module RecursiveMapTests at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:867 overwritten at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:882.
```
Blurb time!
Let's start with a quick overview of the API. I also wrote exhaustive docstrings in the code (just my way of clarifying my thinking, and I hope it's useful for code review too), but this should be a more conversational introduction highlighting the main points.
Out-of-place usage
```julia
(out1::T, out2::T, ...) = recursive_map([seen::IdDict,] f, Val(Nout), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])
```
This generalizes map to recurse through inputs of arbitrary type and map f over all differentiable leaves. The function f should have methods (leaf_out1::U, leaf_out2::U, ...) = f(leaf_in1::U, leaf_in2::U, ...) for every type U whose instances are such leaves. In various docstrings and helper functions I refer to these leaf types as "vector types", since they are supposed to be the lowest-level types that natively implement vector space operations like +, zero, et cetera.
Note how this supports mapped functions f that have an arbitrary number of outputs, assembling them into Nout instances of the input type T rather than a single output containing leaves of types NTuple{Nout,U}. This was necessary to replace accumulate_into from internal_rules.jl. The number of outputs Nout should be passed to recursive_map as Val(Nout). To avoid complexity in the implementation, f has to return a Tuple{U} rather than just U even when there is only a single output. Since this is an internal API I figured this was an acceptable tradeoff.
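To make the single-output contract concrete, here's a minimal sketch (names like `Params` and `halve` are made up for illustration; the call is written per the signature above). Note that even with `Nout == 1`, `f` returns a 1-tuple.

```julia
# Illustrative sketch of a single-output mapped function.
struct Params
    w::Vector{Float64}
    b::Float64
end

halve(x::AbstractFloat) = (x / 2,)        # scalar leaf
halve(x::Vector{Float64}) = (x ./ 2,)     # array leaf

p = Params([2.0, 4.0], 6.0)
# Per the signature above:
#   (q,) = recursive_map(halve, Val(1), (p,))
# would yield q::Params with q.w == [1.0, 2.0] and q.b == 3.0.
```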
Partially-in-place usage
```julia
(new_out1::T, new_out2::T, ...) = recursive_map([seen::IdDict,] f!!, (out1::T, out2::T, ...), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])
```
This form has bangbang-like semantics where mutable storage within the (out1, out2, ...) is mutated and reused, but there is no requirement that every differentiable value is in mutable storage. New immutable containers will be instantiated as needed and you need to use the returned values (new_out1, new_out2, ...) downstream.
To use this form, the mapped function f!! should have the same kind of out-of-place method as f above for all vector types, as well as mutating in-place methods for mutable, non-scalar vector types: (leaf_out1::U, leaf_out2::U, ...) = f!!(leaf_out1::U, leaf_out2::U, ..., leaf_in1::U, leaf_in2::U, ...). (The function should return the mutated outputs because technically you're also allowed to register non-mutable but non-scalar/non-isbits vector types that contain some mutable storage internally and are only partially mutated in place by this method. In such cases, you need to return the new outputs. This is hardly relevant in practice, and there's an argument for simplifying this and requiring that a vector type is either mutable or scalar (or both).)
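For concreteness, a sketch of what such an `f!!` might look like for a single output and two inputs (made-up names, assuming the contract described above):

```julia
# Out-of-place method for immutable (scalar) leaves: returns a 1-tuple.
accum!!(x::T, y::T) where {T<:AbstractFloat} = (x + y,)

# Mutating method for mutable leaves: outputs come first, then inputs,
# and the (mutated) outputs are returned as a tuple.
function accum!!(out::Vector{Float64}, x::Vector{Float64}, y::Vector{Float64})
    out .= x .+ y
    return (out,)
end
```

With methods like these, a call of the form `(new_out,) = recursive_map(accum!!, (out,), (in1, in2))` would reuse the mutable storage in `out` wherever possible and instantiate new immutable containers elsewhere.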
The design of the whole API clicked for me the moment I realized that if you get this version right, both the in-place and out-of-place versions are just special cases, so you only need a single core of recursive functions with a few different entry points for the various use cases.
In-place usage
If you pass types where all differentiable values live in mutable storage, partially-in-place already implements in-place behavior for you. The in-place function recursive_map! is just a wrapper that checks that the inputs satisfy this condition and throws an error otherwise.
```julia
recursive_map!([seen::IdDict,] f!!, (out1::T, out2::T, ...), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])::Nothing
```
f!! should have the same methods as described above for partially-in-place usage. (The parenthetically mentioned partially-in-place method variant for f!! is not relevant here, as the corresponding vector types would fail the in-place validity check.)
Optional arguments
`seen::IdDict` and `Val{copy_if_inactive}` should be familiar from the existing `make_zero` implementation.
- `seen` tracks object identity to reproduce the graph topology of the input objects, i.e., multiple references to the same object and/or cyclical references. One difference is that the in-place form also uses an `IdDict` rather than an `IdSet`, since the inputs and outputs are not generally the same objects even in in-place usage---outputs may just be preallocated and uninitialized storage. (Inputs and outputs being identical is of course still allowed, otherwise `make_zero!` wouldn't work.)
- I spent quite a bit of energy thinking about how to ensure/enforce consistency in graph topology with multiple inputs and outputs. I can elaborate more if needed, but the punchline is that the first input `in1` is the reference version: other inputs must at least mirror all the aliasing seen in the first input; new outputs will reproduce the topology of the first input; and existing, reused outputs must have no more aliasing than the first input; all this so that every output leaf value is well-defined.
- `Val{copy_if_inactive}` decides whether non-differentiable subgraphs should be deepcopied or shared between inputs and outputs. `isinactivetype` is a callable that decides exactly what these "non-differentiable subgraphs" are. This was a tricky one for multiple reasons, so I'll give it a dedicated paragraph below.
isinactivetype and IsInactive
There were three concerns:
- Performance: This was never that important for `make_zero`, but a key goal of `recursive_map` is to be the basis for writing rules for higher-order functions like quadrature routines. When you need to map vector space operations over hundreds or even thousands of closure instances, performance is important. `guaranteed_const_nongen` ended up being a performance bottleneck due to frequent type instabilities. I therefore decided to make the compile-time version `guaranteed_const` the default. However, this means new methods added to `EnzymeRules.inactive_type` won't generally be picked up by subsequent calls to `recursive_map`. Hence, there should be a simple way to choose the `_nongen` version when needed.
- Consistency: When the in-place `recursive_map!` validates input/output, it should ignore non-differentiable subgraphs; what's important is that the differentiable values live in mutable storage. Hence this needs to take `isinactivetype` into account. When using the compile-time `guaranteed_const`, there exists a corresponding `guaranteed_nonactive` that ensures this consistency. I added `guaranteed_nonactive_nongen` to complete the 2x2 matrix. Now all that's needed is an API for choosing between `guaranteed_const` and `guaranteed_const_nongen` that also ensures that the corresponding version of `guaranteed_nonactive*` is used for validation.
- Extras: Some uses of `recursive_map(!)` need to terminate at (i.e., treat as inactive) additional parts of the object graph. Case in point: `recursive_accumulate!`, now renamed to `accumulate_seen!`, used for holomorphic derivatives---this should terminate whenever it encounters a type that would be cached in `seen`. Hence it must be possible to pass a custom `treat_this_as_inactive_too` callable to be used alongside `guaranteed_const*`, and for consistency, `recursive_map!` argument validation should also hook into this.
Rather than a proliferation of optional arguments and internal voodoo, this called for a dedicated abstraction. Hence the IsInactive struct, which works as follows:
- Instantiation: `isinactivetype = IsInactive{runtime}([extra])`, where `runtime` is a `Bool` and the optional `extra` is your `treat_this_as_inactive_too` function described above.
- The signature of `isinactivetype` as a callable is `isinactivetype(T, [Val(nonactive)])::Bool`
- If `runtime == false` you get compile-time evaluation (no `*_nongen`): `isinactivetype(T)` and `isinactivetype(T, Val(false))` call `guaranteed_const(T) || extra(T)`; `isinactivetype(T, Val(true))` calls `guaranteed_nonactive(T) || extra(T)`
- If `runtime == true` you get runtime evaluation (`*_nongen`): `isinactivetype(T)` and `isinactivetype(T, Val(false))` call `guaranteed_const_nongen(T) || extra(T)`; `isinactivetype(T, Val(true))` calls `guaranteed_nonactive_nongen(T) || extra(T)`
This functionality is the reason behind the method overwriting warning pointed out by @vchuravy above:
```
WARNING: Method definition inactive_type(Type{var"#s2630"} where var"#s2630"<:(RecursiveMapTests.CustomVector{T} where T)) in module RecursiveMapTests at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:867 overwritten at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:882.
```
This warning comes from the test suite, not the package code. The test suite overwrites inactive_type a couple of times to verify that the various modes have the expected behavior (runtime = true picks up the new methods, runtime = false does not).
isinactivetype in practice
The optional argument ::Val{runtime}=Val(false) was added to make_zero(!) to let the user choose between compile-time and runtime inactivity checking. Under the hood, this passes IsInactive{runtime}() to recursive_map(!).
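A hedged sketch of what that looks like at the call site (the exact position of the new `Val(runtime)` argument relative to `Val(copy_if_inactive)` is an assumption on my part):

```julia
x = (grad = [1.0, 2.0], label = "not differentiable")

make_zero(x)  # default: IsInactive{false}(), compile-time guaranteed_const
# make_zero(x, Val(copy_if_inactive), Val(true))  # IsInactive{true}(): runtime
#     # checking that picks up newly added EnzymeRules.inactive_type methods
```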
A word on generated functions
Unrolling loops over tuples and struct fields is crucial for type stability and performance. I tried to only use ntuple and recursive tuple peeling idioms combined with aggressive @inlineing, but couldn't consistently eliminate as many allocations as I could with simple and arguably less convoluted generated functions. Moreover, I saw some segfaults when using ntuple to assemble arguments for splatnew. Hence, I decided to use @generated for a few core functions. However, I made them as benign as I possibly could: there is no manual Expr handling, only quoted code where select compile-time constants are interpolated into the Base.@ntuple and Base.Cartesian.@nexprs macros. Perhaps the most questionable choice is the single use of @goto, but how else do you emulate break in an unrolled loop? (However, that loop is the least important one to unroll, so we can remove the @goto if you want.) I avoided gratuitous use of @generated when recursive tuple peeling worked just as well, such as for the getitems and setitems! helpers.
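For illustration, the general shape of such a "benign" generated function (a made-up example, not the PR's actual code): a compile-time constant is interpolated into `Base.Cartesian.@nexprs`, and `@goto` emulates `break` in the unrolled loop.

```julia
# Made-up example: check whether any field of x is undefined, with the
# field loop unrolled at compile time. No manual Expr handling, only
# quoted code with an interpolated compile-time constant.
@generated function anyundef(x)
    N = fieldcount(x)  # compile-time constant for this method instance
    return quote
        found = false
        Base.Cartesian.@nexprs $N i -> begin
            if !isdefined(x, i)
                found = true
                @goto done  # emulate `break` in the unrolled loop
            end
        end
        @label done
        return found
    end
end
```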
Generality/GPU compatibility
When the recursion hits a DenseArray{T} where T is a bitstype, it uses broadcasting rather than scalar iteration to recurse into the array. This means that everything should hopefully work out of the box for GPU arrays even when the array eltype is non-scalar (e.g., a struct or tuple). This is tested for JLArray, but I haven't actually checked if it works for CuArray or MtlArray. (The question is just whether the broadcasted closure compiles on GPU; I don't see why it shouldn't as long as the mapped function f works on GPU and you use IsInactive{false} to avoid type instabilities, but this remains to be seen. I don't know how well GPUCompiler does with recursion.)
In the common case where the GPU array eltype is scalar (i.e., float), the array is considered a leaf and dispatched directly to the mapped function f without further broadcast/recursion, so this case should definitely work on GPU, provided f itself does.
> `guaranteed_const_nongen` ended up being a performance bottleneck due to frequent type instabilities.
I did some more profiling, and this is not true anymore, possibly due to the @assume_effects annotations added in the meantime. Using *_nongen versions of the functions is now at least as fast as the alternative. I'll go ahead and simplify the implementation accordingly.
Wow, the format CI action really went to town here! I've accepted the suggestions that made sense, but Runic seems to have lost the plot in the remaining cases, suggesting arbitrary mashups of nearby lines. I can see what changes would actually be needed to make it happy, but I'll hold off on that while reporting this to Runic.jl.
OK, runic is finally happy. As explained in the linked issue, the nonsense suggestions were not due to Runic.jl itself, but to googleapis/code-suggester, which chops the full diff into hunks and turns them into code review suggestions, and which seems dangerously flawed. Perhaps the new version in #2276 is an improvement, but it's still not bug-free as shown in this issue: https://github.com/googleapis/code-suggester/issues/505
Thanks for reviewing! If you can hang on a moment before merging, I'd like to address your review comment above and return to a simpler single-output API.
Alright, there you go! Sorry that the diff from the last changes ended up being rather large despite most of the changes being quite superficial (the main thing is changing the type of the y argument from `Union{Val{Nout},NTuple{Nout,T}}` to `Union{Nothing,Some{T}}`, with concomitant simplifications because the output is no longer wrapped in a tuple). Please let me know if there's anything I can do to help make the reviewing as smooth as possible.
I realize it's a bit strange that some of the most thoroughly documented functions in the repo are now these fringe internal routines. Writing docs is just my way of thinking through the interface/contract. Happy to trim the docstrings to something more reasonable if you want me to.
Otherwise, this is good to go from my side.
Anything blocking this?
> I realize it's a bit strange that some of the most thoroughly documented functions in the repo are now these fringe internal routines. Writing docs is just my way of thinking through the interface/contract. Happy to trim the docstrings to something more reasonable if you want me to.
They could probably be their own library.
> They could probably be their own library.
That would have been great, but the implementation is too tightly coupled to Enzyme internals like active_reg_inner.
> Anything blocking this?
Technically I don't think so? But I understand that it's a big diff to review and folks are busy, including me, which is why I haven't been bumping lately. Locally I just rebase on main every now and then and keep working on my stuff.