
Add recursive map generalizing the make_zero mechanism

Open danielwe opened this issue 1 year ago • 20 comments

This is to explore functionality for realizing https://github.com/JuliaMath/QuadGK.jl/pull/120. The current draft cuts time and allocations in half for the MWE in that PR compared to the make_zero hack from the comments. Not sure if modifying the existing recursive_* functions like this is appropriate or whether it would be better to implement a separate deep_recursive_accumulate.

This probably breaks some existing uses of recursive_accumulate, like the Holomorphic derivative code, because recursive_accumulate now traverses most/all of the structure on its own and will double-accumulate when combined with the iteration over the seen IdDicts. Curious to see the total impact on the test suite.

This doesn't yet have any concept of seen and will thus double-accumulate if the structure has internal aliasing. That obviously needs to be fixed. Perhaps we can factor out and share the recursion code from make_zero.

A bit of a tangent, but perhaps a final version of this PR should include migrating ClosureVector to Enzyme from the QuadGK ext as suggested in https://github.com/JuliaMath/QuadGK.jl/pull/110#issuecomment-2323012377. Looks like that's the most relevant application of fully recursive accumulation at the moment.


Let me also throw out another suggestion: what if we implement a recursive generalization of broadcasting with an arbitrary number of arguments, i.e., recursive_broadcast!(f, a, b, c, ...) as a recursive generalization of a .= f.(b, c, ...), free of intermediate allocations whenever possible (and similarly an out-of-place recursive_broadcast(f, a, b, c...) generalizing f.(a, b, c...) that only materializes/allocates once if possible). That would enable more optimized custom rules with Duplicated args, such as having the QuadGK rule call the in-place version quadgk!(f!, result, segs...). Not sure if it would be hard to correctly handle aliasing without being overly defensive, or if that could mostly be taken care of by proper reuse of the existing broadcasting functionality.

danielwe avatar Sep 18 '24 06:09 danielwe

Alright, I could take some feedback/discussion on this now.

  • This implements a generic recursive_map for mapping a function over the differentiable values in arbitrary tuples of identical data structures. There's an in-place equivalent recursive_map! for mutable values, built on top of a bangbang-style recursive_map!! that works on arbitrary types and reuses all the mutable memory (similar to the old make_zero_immutable! but without code duplication).
  • GitHub diffs the code against the old make_zero(!) code, but this is a complete rewrite and the diff will probably not be helpful for reviewing. Let me know if you want me to rename files or something to get rid of GitHub's diff view.
    • The implementation is leaner and simpler than the old one, and even though recursive_map{!!} aren't public I wrote extensive docstrings to clarify the spec for myself and others, so I don't think the code should be too difficult to review from scratch.
  • I added fast paths such that new structs are allocated using splatnew with a tuple instead of ccall with a vector in the common case where there are no undefined fields. This gives a substantial speedup in many cases, which is good since recursive_map will be called in hot loops in custom quadrature rules and the like.
  • I have refactored make_zero and make_zero! to be minimal wrappers around recursive_map{!}, without changing their public API.
    • To stay safe while doing such a big refactoring, I wrote extensive tests with ~full coverage of both the old and new implementations of make_zero(!) (a small number of edge case branches aren't covered only because I can't find a way to reach them from any public entry point).
    • These tests uncovered quite a few bugs in the existing make_zero(!) implementations. See the following commit on a separate branch in my fork for the necessary fixes to get the old code to pass the new tests: https://github.com/danielwe/Enzyme.jl/commit/7a6ca9ff332b413c80a2cc540a92c79eb026f2eb. (Note, one of the tests still errors due to #1935.)
  • I have not refactored recursive_add and recursive_accumulate, since these currently have different semantics where they don't recurse into mutable values. I'm happy to go ahead and do the refactoring if you're OK with changing their semantics.

TLDR: Should I rewrite recursive_add and recursive_accumulate to be based on recursive_map{!} and have full recursion semantics? Anything else?

@gdalle I promised to tag you when this was ready for review, but note that this PR only deals with the low-level, non-public guts of the implementation. I'll do the vector space wrapper in a separate PR as soon as this is merged (hopefully that won't be long, I really need that QuadGK rule for my research :triangular_ruler:)

danielwe avatar Oct 07 '24 01:10 danielwe

Update for anyone who's following: I've implemented the VectorSpace wrapper, which prompted me to adjust the recursive_map implementation a bit, all for the better. It's looking good and will make writing custom higher-order rules as well as the DI wrappers a lot nicer for arbitrary types. However, it dawned on me that you probably want make_zero to be easily extensible by just adding methods, like what's already done in the StaticArrays extension. That will require a bit of redesign, nothing too hard, but I've got weekend plans so might not get to it until next week.

danielwe avatar Oct 11 '24 16:10 danielwe

Awesome, sorry I haven't had a chance to review yet [just a bunch of shenanigans atm]. I'll try to take a closer look next week; ping me if not.

wsmoses avatar Oct 11 '24 19:10 wsmoses

No worries! I restored the draft label when I realized there was a bit more to do and will remove it again once I think this is ready for review. No need to look at it until then, the current state here on github doesn't reflect what I'm working with locally anyway.

danielwe avatar Oct 11 '24 19:10 danielwe

At long last, I think this one's ready for you to take a look. Hit me with any questions and concerns, from major design issues to bikeshedding over names.

I put both the implementation and tests in their own modules because they define a lot of helpers and I didn't want to pollute other modules' namespaces.

danielwe avatar Oct 31 '24 19:10 danielwe

Codecov Report

Attention: Patch coverage is 87.69716% with 39 lines in your changes missing coverage. Please review.

Project coverage is 75.35%. Comparing base (037dfed) to head (8bea510). Report is 355 commits behind head on main.

Files with missing lines Patch % Lines
src/typeutils/recursive_maps.jl 91.76% 22 Missing :warning:
src/typeutils/recursive_add.jl 69.04% 13 Missing :warning:
src/analyses/activity.jl 0.00% 3 Missing :warning:
src/internal_rules.jl 0.00% 1 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1852      +/-   ##
==========================================
+ Coverage   67.50%   75.35%   +7.84%     
==========================================
  Files          31       56      +25     
  Lines       12668    16635    +3967     
==========================================
+ Hits         8552    12535    +3983     
+ Misses       4116     4100      -16     


codecov-commenter avatar Oct 31 '24 19:10 codecov-commenter

I'm still knee deep in 1.11 land and don't have cycles to review this immediately. @vchuravy can you take a look?

wsmoses avatar Nov 05 '24 20:11 wsmoses

1.11 efforts deeply appreciated! Don't rush this. I'll keep using 1.10 and a local fork for my own needs, and occasionally push small changes here as my tinkering surfaces new concerns/opportunities.

danielwe avatar Nov 05 '24 21:11 danielwe

@danielwe is this the one now to review or is there a different related PR I should review first? (and also would you mind rebasing)

wsmoses avatar Dec 07 '24 17:12 wsmoses

This is the one wrt. recursive maps and all that, but I've been continually refining stuff locally, so I need to both rebase and push the latest changes, hang on! Will remove draft status when ready for review.

danielwe avatar Dec 07 '24 19:12 danielwe

Finally wrapped up and rebased this! I'll come back later and write a little blurb, but the code should be ready for review as-is.

danielwe avatar Jan 09 '25 02:01 danielwe

Nice!

WARNING: Method definition inactive_type(Type{var"#s2630"} where var"#s2630"<:(RecursiveMapTests.CustomVector{T} where T)) in module RecursiveMapTests at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:867 overwritten at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:882.

vchuravy avatar Jan 09 '25 08:01 vchuravy

Blurb time!

Let's start with a quick overview of the API. I also wrote exhaustive docstrings in the code (just my way of clarifying my thinking, and I hope it's useful for code review too), but this should be a more conversational introduction highlighting the main points.

Out-of-place usage

(out1::T, out2::T, ...) = recursive_map([seen::IdDict,] f, Val(Nout), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])

This generalizes map to recurse through inputs of arbitrary type and map f over all differentiable leaves. The function f should have methods (leaf_out1::U, leaf_out2::U, ...) = f(leaf_in1::U, leaf_in2::U, ...) for every type U whose instances are such leaves. In various docstrings and helper functions I refer to these leaf types as "vector types", since they are supposed to be the lowest-level types that natively implement vector space operations like +, zero, et cetera.

Note how this supports mapped functions f that have an arbitrary number of outputs, assembling them into Nout instances of the input type T rather than a single output containing leaves of types NTuple{Nout,U}. This was necessary to replace accumulate_into from internal_rules.jl. The number of outputs Nout should be passed to recursive_map as Val(Nout). To avoid complexity in the implementation, f has to return a Tuple{U} rather than just U even when there is only a single output. Since this is an internal API I figured this was an acceptable tradeoff.
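To make the calling convention above concrete, here is a minimal sketch of the out-of-place form. This is an internal API, so the module path (assumed here to be `Enzyme.Compiler.RecursiveMaps`) and the exact signature are assumptions based on this thread and may differ from the merged code:

```julia
using Enzyme
const RM = Enzyme.Compiler.RecursiveMaps  # assumed internal module path

struct Params
    w::Vector{Float64}
    b::Float64
end

# One output (Nout = 1), so f must return a 1-tuple. The leaves here are
# Vector{Float64} and Float64, both of which support `+` natively.
addleaves(x::U, y::U) where {U} = (x + y,)

p1 = Params([1.0, 2.0], 3.0)
p2 = Params([10.0, 20.0], 30.0)
(psum,) = RM.recursive_map(addleaves, Val(1), (p1, p2))
# expected per the spec above: psum.w == [11.0, 22.0], psum.b == 33.0
```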

Partially-in-place usage

(new_out1::T, new_out2::T, ...) = recursive_map([seen::IdDict,] f!!, (out1::T, out2::T, ...), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])

This form has bangbang-like semantics where mutable storage within the (out1, out2, ...) is mutated and reused, but there is no requirement that every differentiable value is in mutable storage. New immutable containers will be instantiated as needed and you need to use the returned values (new_out1, new_out2, ...) downstream.

To use this form, the mapped function f!! should have the same kind of out-of-place method as f above for all vector types, as well as mutating in-place methods for mutable, non-scalar vector types: (leaf_out1::U, leaf_out2::U, ...) = f!!(leaf_out1::U, leaf_out2::U, ..., leaf_in1::U, leaf_in2::U, ...). (The function should return the mutated outputs because technically you're also allowed to register non-mutable but non-scalar/non-isbits vector types that contain some mutable storage internally and are only partially mutated in place by this method. In such cases, you need to return the new outputs. This is hardly relevant in practice, and there's an argument for simplifying this and requiring that a vector type is either mutable or scalar (or both).)
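A minimal sketch of such an f!!, assuming one output and two inputs (the name `accum!!` and the restriction of the in-place method to `Array` are my own choices for illustration):

```julia
# Out-of-place method, used when the leaf is immutable (e.g. Float64);
# returns a 1-tuple since Nout = 1:
accum!!(x::U, y::U) where {U} = (x + y,)

# In-place method for mutable, non-scalar leaves (e.g. Vector{Float64});
# mutates `out` and returns it wrapped in a 1-tuple, per the contract above:
function accum!!(out::U, x::U, y::U) where {U<:Array}
    out .= x .+ y
    return (out,)
end
```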

The design of the whole API clicked for me the moment I realized that if you get this version right, both the in-place and out-of-place versions are just special cases, so you only need a single core of recursive functions with a few different entry points for the various use cases.

In-place usage

If you pass types where all differentiable values live in mutable storage, partially-in-place already implements in-place behavior for you. The in-place function recursive_map! is just a wrapper that checks that the inputs satisfy this condition and throws an error otherwise.

recursive_map!([seen::IdDict,] f!!, (out1::T, out2::T, ...), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])::Nothing

f!! should have the same methods as described above for partially-in-place usage. (The parenthetically mentioned partially-in-place method variant for f!! is not relevant here, as the corresponding vector types would fail the in-place validity check.)
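As a hedged sketch of the fully in-place form, here is roughly what a make_zero!-style call would look like, assuming the internal module path `Enzyme.Compiler.RecursiveMaps` and a structure where every differentiable value sits in mutable storage:

```julia
using Enzyme
const RM = Enzyme.Compiler.RecursiveMaps  # assumed internal module path

# In-place method for mutable leaves; zeroes `out` and ignores `in`:
zero!!(out::U, in::U) where {U<:Array} = (out .= 0; (out,))

shadow = (grad = [1.0, 2.0],)  # NamedTuple wrapping a mutable array
RM.recursive_map!(zero!!, (shadow,), (shadow,))
# expected: shadow.grad == [0.0, 0.0]
```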

Optional arguments

seen::IdDict and Val{copy_if_inactive} should be familiar from the existing make_zero implementation.

  • seen tracks object identity to reproduce the graph topology of the input objects, i.e., multiple references to the same object and/or cyclical references. One difference is that the in-place form also uses an IdDict rather than an IdSet, since the inputs and outputs are not generally the same objects even in in-place usage---outputs may just be preallocated and uninitialized storage. (Inputs and outputs being identical is of course still allowed, otherwise make_zero! wouldn't work.)
    • I spent quite a bit of energy thinking about how to ensure/enforce consistency in graph topology with multiple inputs and outputs. I can elaborate more if needed, but the punchline is that the first input in1 is the reference version: other inputs must at least mirror all the aliasing seen in the first input; new outputs will reproduce the topology of the first input; and existing, reused outputs must have no more aliasing than the first input; all this so that every output leaf value is well-defined.
  • Val{copy_if_inactive} decides whether non-differentiable subgraphs should be deepcopied or shared between inputs and outputs.
  • isinactivetype is a callable that decides exactly what these "non-differentiable subgraphs" are. This was a tricky one for multiple reasons, so I'll give it a dedicated paragraph below.

isinactivetype and IsInactive

There were three concerns:

  • Performance: This was never that important for make_zero, but a key goal of recursive_map is to be the basis for writing rules for higher-order functions like quadrature routines. When you need to map vector space operations over hundreds or even thousands of closure instances, performance is important.

    guaranteed_const_nongen ended up being a performance bottleneck due to frequent type instabilities. I therefore decided to make the compile-time version guaranteed_const the default. However, this means new methods added to EnzymeRules.inactive_type won't generally be picked up by subsequent calls to recursive_map. Hence, there should be a simple way to choose the _nongen version when needed.

  • Consistency: When the in-place recursive_map! validates input/output, it should ignore non-differentiable subgraphs; what's important is that the differentiable values live in mutable storage. Hence this needs to take isinactivetype into account. When using the compile-time guaranteed_const, there exists a corresponding guaranteed_nonactive that ensures this consistency. I added guaranteed_nonactive_nongen to complete the 2x2 matrix. Now all that's needed is an API for choosing between guaranteed_const and guaranteed_const_nongen that also ensures that the corresponding version of guaranteed_nonactive* is used for validation.

  • Extras: Some uses of recursive_map(!) need to terminate at (i.e., treat as inactive) additional parts of the object graph. Case in point: recursive_accumulate!, now renamed to accumulate_seen!, used for holomorphic derivatives---this should terminate whenever it encounters a type that would be cached in seen. Hence it must be possible to pass a custom treat_this_as_inactive_too callable to be used alongside guaranteed_const*, and for consistency, recursive_map! argument validation should also hook into this.

Rather than a proliferation of optional arguments and internal voodoo, this called for a dedicated abstraction. Hence the IsInactive struct, which works as follows:

  • Instantiation: isinactivetype = IsInactive{runtime}([extra]) where runtime is a Bool and the optional extra is your treat_this_as_inactive_too function described above.
  • The signature of isinactivetype as a callable is isinactivetype(T, [Val(nonactive)])::Bool
  • If runtime == false you get compile-time evaluation (no *_nongen):
    • isinactivetype(T) and isinactivetype(T, Val(false)) call guaranteed_const(T) || extra(T)
    • isinactivetype(T, Val(true)) calls guaranteed_nonactive(T) || extra(T)
  • If runtime == true you get runtime evaluation (*_nongen):
    • isinactivetype(T) and isinactivetype(T, Val(false)) call guaranteed_const_nongen(T) || extra(T)
    • isinactivetype(T, Val(true)) calls guaranteed_nonactive_nongen(T) || extra(T)
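In code, the dispatch rules above look roughly like this (internal names from this thread; the `iscached` extra predicate is hypothetical):

```julia
using Enzyme
const RM = Enzyme.Compiler.RecursiveMaps  # assumed internal module path

isinactive = RM.IsInactive{false}()     # compile-time: guaranteed_const
isinactive(String)                      # expected true: strings are inactive
isinactive(Vector{Float64})             # expected false
isinactive(Vector{Float64}, Val(true))  # guaranteed_nonactive variant

# Runtime variant with an extra termination predicate, e.g. to also treat
# all mutable types as inactive (hypothetical choice for illustration):
iscached(T::Type) = ismutabletype(T)
isinactive_rt = RM.IsInactive{true}(iscached)
```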

This functionality is the reason behind the method overwriting warning pointed out by @vchuravy above:

WARNING: Method definition inactive_type(Type{var"#s2630"} where var"#s2630"<:(RecursiveMapTests.CustomVector{T} where T)) in module RecursiveMapTests at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:867 overwritten at /home/runner/work/Enzyme.jl/Enzyme.jl/test/recursive_maps.jl:882.

This warning comes from the test suite, not the package code. The test suite overwrites inactive_type a couple of times to verify that the various modes have the expected behavior (runtime = true picks up the new methods, runtime = false does not).

isinactivetype in practice

The optional argument ::Val{runtime}=Val(false) was added to make_zero(!) to let the user choose between compile-time and runtime inactivity checking. Under the hood, this passes IsInactive{runtime}() to recursive_map(!).
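A sketch of the resulting user-facing calls; the position of the Val{runtime} argument after the existing Val{copy_if_inactive} argument is an assumption on my part:

```julia
using Enzyme

x = (grad = [1.0, 2.0], tag = "meta")
x0 = make_zero(x)  # default: compile-time inactivity checks (runtime = false)
# Runtime checks, picking up freshly added EnzymeRules.inactive_type methods;
# the first Val is the existing copy_if_inactive argument (assumed position):
x0_rt = make_zero(x, Val(false), Val(true))
```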

A word on generated functions

Unrolling loops over tuples and struct fields is crucial for type stability and performance. I tried to use only ntuple and recursive tuple-peeling idioms combined with aggressive use of @inline, but couldn't consistently eliminate as many allocations as I could with simple and arguably less convoluted generated functions. Moreover, I saw some segfaults when using ntuple to assemble arguments for splatnew. Hence, I decided to use @generated for a few core functions. However, I made them as benign as I possibly could: there is no manual Expr handling, only quoted code where select compile-time constants are interpolated into the Base.@ntuple and Base.Cartesian.@nexprs macros. Perhaps the most questionable choice is the single use of @goto, but how else do you emulate break in an unrolled loop? (However, that loop is the least important one to unroll, so we can remove the @goto if you want.) I avoided gratuitous use of @generated when recursive tuple peeling worked just as well, such as for the getitems and setitems! helpers.
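For illustration, here is the kind of benign generated function meant above, standing in for the actual internal helpers: only quoted code, a compile-time constant interpolated into Base.Cartesian.@nexprs, and @goto emulating break in the unrolled loop (the function itself is a toy example, not code from the PR):

```julia
# Returns the index of the first undefined field of x, or 0 if none.
@generated function first_undef_index(x::T) where {T}
    N = fieldcount(T)  # compile-time constant interpolated below
    return quote
        i = 0
        Base.Cartesian.@nexprs $N j -> begin  # unrolls to N literal checks
            if !isdefined(x, j)
                i = j
                @goto done  # emulate `break` in the unrolled loop
            end
        end
        @label done
        return i
    end
end

mutable struct TwoFields
    a::Any
    b::Any
    TwoFields(a) = new(a)  # leaves field b undefined
end
first_undef_index(TwoFields(1))  # expected: 2
```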

Generality/GPU compatibility

When the recursion hits a DenseArray{T} where T is a bitstype, it uses broadcasting rather than scalar iteration to recurse into the array. This means that everything should hopefully work out of the box for GPU arrays even when the array eltype is non-scalar (e.g., a struct or tuple). This is tested for JLArray, but I haven't actually checked if it works for CuArray or MtlArray. (The question is just whether the broadcasted closure compiles on GPU; I don't see why it shouldn't as long as the mapped function f works on GPU and you use IsInactive{false} to avoid type instabilities, but this remains to be seen. I don't know how well GPUCompiler does with recursion.)

In the common case where the GPU array eltype is scalar (i.e., float), the array is considered a leaf and dispatched directly to the mapped function f without further broadcast/recursion, so this case should definitely work on GPU, provided f itself does.
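The non-scalar-eltype case boils down to a broadcast over a closure that recurses into each element, which is why it should carry over to GPU arrays. A plain-Array illustration of the shape of that broadcast (names here are mine, not the internal ones):

```julia
# Bitstype, non-scalar eltype: the array is not a leaf, so recursion
# descends into each element's fields via a single broadcast.
struct Point
    x::Float64
    y::Float64
end

scaleleaf(v::Float64) = 2v  # stand-in for the mapped function f

a = [Point(1.0, 2.0), Point(3.0, 4.0)]
b = (p -> Point(scaleleaf(p.x), scaleleaf(p.y))).(a)
# expected: b == [Point(2.0, 4.0), Point(6.0, 8.0)]
```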

danielwe avatar Jan 09 '25 17:01 danielwe

  • guaranteed_const_nongen ended up being a performance bottleneck due to frequent type instabilities.

I did some more profiling, and this is not true anymore, possibly due to the @assume_effects annotations added in the meantime. Using *_nongen versions of the functions is now at least as fast as the alternative. I'll go ahead and simplify the implementation accordingly.

danielwe avatar Jan 10 '25 23:01 danielwe

Wow, the format CI action really went to town here! I've accepted the suggestions that made sense, but Runic seems to have lost the plot in the remaining cases, suggesting arbitrary mashups of nearby lines. I can see what changes would actually be needed to make it happy, but I'll hold off on that while reporting this to Runic.jl.

danielwe avatar Jan 22 '25 04:01 danielwe

OK, runic is finally happy. As explained in the linked issue, the nonsense suggestions were not due to Runic.jl itself, but to googleapis/code-suggester, which chops the full diff into hunks and turns them into code review suggestions, and which seems dangerously flawed. Perhaps the new version in #2276 is an improvement, but it's still not bug-free as shown in this issue: https://github.com/googleapis/code-suggester/issues/505

danielwe avatar Jan 22 '25 21:01 danielwe

Thanks for reviewing! If you can hang on a moment before merging, I'd like to address your review comment above and return to a simpler single-output API.

danielwe avatar Mar 10 '25 15:03 danielwe

Alright, there you go! Sorry that the diff from the last changes ended up being rather large despite most of the changes being quite superficial (the main thing is changing the type of the y argument from Union{Val{Nout},NTuple{Nout,T}} to Union{Nothing,Some{T}}, with concomitant simplifications because the output is no longer wrapped in a tuple). Please let me know if there's anything I can do to help make the reviewing as smooth as possible.

I realize it's a bit strange that some of the most thoroughly documented functions in the repo are now these fringe internal routines. Writing docs is just my way of thinking through the interface/contract. Happy to trim the docstrings to something more reasonable if you want me to.

Otherwise, this is good to go from my side.

danielwe avatar Mar 13 '25 00:03 danielwe

Anything blocking this?

I realize it's a bit strange that some of the most thoroughly documented functions in the repo are now these fringe internal routines. Writing docs is just my way of thinking through the interface/contract. Happy to trim the docstrings to something more reasonable if you want me to.

They could probably be their own library.

ChrisRackauckas avatar May 19 '25 09:05 ChrisRackauckas

They could probably be their own library.

That would have been great, but the implementation is too tightly coupled to Enzyme internals like active_reg_inner.

Anything blocking this?

Technically I don't think so? But I understand that it's a big diff to review and folks are busy, including me which is why I haven't been bumping lately. Locally I just rebase on main every now and then and keep working on my stuff.

danielwe avatar May 19 '25 14:05 danielwe