Provide `is_pointerlike(x)` and `are_aliased(x, y)` to use for runtime activity
Goal: enable users to check the aliasing between primal and shadow when runtime activity is on (https://enzymead.github.io/Enzyme.jl/stable/faq/#faq-runtime-activity):
Enabling runtime activity does therefore, come with a sharp edge, which is that if the computed derivative of a function is mutable, one must also check to see if the primal and shadow represent the same pointer, and if so the true derivative of the function is actually zero.
Related:
- https://github.com/EnzymeAD/Enzyme.jl/pull/2559#discussion_r2342976028
- https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/666#issuecomment-3279224954
- https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/850
this is likely a fairly generic question and I'd recommend you seek help on this topic in upstream julia if myself or @vchuravy are unable to help you figure out how to implement this in a timely fashion
The reason why I suggest Enzyme should implement this is because
- as evidenced by our previous discussions, it is far from trivial to even find Julia functions that do something like that
- the precise semantics that you need are not obvious (for example the documentation on runtime activity suggests using `===`, which is apparently insufficient)
- every Enzyme user will need this functionality to ensure that their runtime activity code is correct, not just me inside DI
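To make the point about `===` concrete, here is a small sketch in plain Julia (no Enzyme involved). It only illustrates general `===` semantics; it is an assumption on my part that these are the cases that make the documented check insufficient.

```julia
# `===` compares immutable values by content, so equal floats are "identical"
# even though no memory is shared.
x = 1.0
y = 1.0
floats_egal = x === y                     # true, but floats are not pointer-like

# For mutable arrays, `===` is object identity.
a = [1.0]
b = [1.0]
same_object = a === a                     # true
equal_but_distinct = a === b              # false

# Two different array objects can still share memory.
r = reshape(a, 1, 1)
reshaped_identical = a === r              # false: different type and object
reshaped_shares_memory = pointer(a) == pointer(r)   # true

# `ismutable` separates the two kinds of values.
float_is_mutable = ismutable(x)           # false
array_is_mutable = ismutable(a)           # true
```

So a check based on `===` alone can both over-report (equal immutable values) and under-report (distinct wrappers around the same memory).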
I'm in no rush to implement this, I just took a look to fix https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/666 like you asked. And the answers I found make me think that if the checks are not available inside Enzyme, no other user besides DI will ever go to the trouble of trying to figure them out.
I mean, most users probably have a function with a single return type (e.g. an array), in which case the check is the simple one (if it is needed at all for their return type; e.g. it is not needed for a float).
The complication for your case is wanting to do this as reflection in Julia, e.g. not for a fixed type
I will also note, which goes to explain why DI will have significant work to implement this: if you have, say, Tuple{Vector, Vector} as the return type, you must check each element of the tuple separately to see whether primal and shadow have equal pointers (if so, that element's derivative is zero)
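For the Tuple{Vector, Vector} case, the element-wise check could look roughly like the sketch below. The `primal` and `shadow` values are made up by hand rather than produced by Enzyme, and `zero` stands in for `Enzyme.make_zero` so the snippet runs without Enzyme.

```julia
# Hand-made stand-ins for what a runtime-activity call might return:
# the first shadow element is the very same array as the first primal element.
primal = ([1.0, 2.0], [3.0])
shadow = (primal[1], [0.5])

# Check each tuple element separately; for arrays `===` is object identity.
aliased_mask = map(===, primal, shadow)

# Replace aliased elements by zeros, keep the others as they are.
derivative = map((p, s) -> p === s ? zero(s) : s, primal, shadow)
```

Here only the first element is aliased, so only that one is replaced by zeros. A single "is anything aliased" answer would not be enough to build `derivative`.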
True but if we manage to find a general criterion and put it inside DI, it forces every user to be safe when activating runtime activity. That section of the docs is rather easy to miss for someone who doesn't dive deep into Enzyme
That's why I'm looking for a systematic implementation. And even if someone doesn't want to use DI, at least we can point there and say "that's how you should handle aliasing, we figured this out for you 🎁"
So a part answer is something along this line:
function aliased(a::T, b::T) where T
    a === b
end
function recursive_aliased(a::T, b::T) where {T}
    if aliased(a, b)
        return true
    end
    # Otherwise check children
    child_aliasing = ntuple(Val(fieldcount(T))) do i
        @inline
        _a = getfield(a, i)
        _b = getfield(b, i)
        recursive_aliased(_a, _b)
    end
    any(child_aliasing)
end
I was originally trying to write something with Accessors (maybe @aplavin or @jw3126 have an idea)
This does handle https://github.com/EnzymeAD/Enzyme.jl/pull/2559#discussion_r2349713345 correctly, but does not handle the "fun" case of self-recursive types or types with uninitialized fields.
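The two problematic cases can be reproduced with a toy type. `Node` below is a hypothetical example type, not something from Enzyme or DI.

```julia
# A mutable, self-referential type whose inner constructor leaves `next` undefined.
mutable struct Node
    value::Float64
    next::Node
    Node(v) = new(v)
end

n = Node(1.0)

# Uninitialized field: `getfield` throws, so a traversal has to ask `isdefined` first.
next_is_undefined = !isdefined(n, :next)
threw_undef = try
    getfield(n, :next)
    false
catch err
    err isa UndefRefError
end

# Self-recursive value: walking the fields without a "seen" set never terminates.
n.next = n
is_cyclic = n.next.next === n
```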
I'm not sure the alias function on its own is the one you want here, essentially because the "or of all aliasing" doesn't really give you actionable info (since you need the recursive alias info). I think essentially you want a dealias function that looks something like:
function dealias(primal, shadow)
    if not leaf
        recurse through elements
    else
        if mutable and aliasing
            return make_zero()
        else
            return shadow
        end
    end
end
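One way to turn this pseudocode into something runnable is sketched below. It rests on a narrow and assumed definition of "leaf": arrays are mutable leaves, tuples and named tuples are containers, everything else is an immutable leaf. `dealias_sketch` is a hypothetical name, and `zero` stands in for `Enzyme.make_zero` so the snippet has no dependencies.

```julia
# Mutable leaf: an aliased shadow means the true derivative is zero.
dealias_sketch(primal::AbstractArray, shadow::AbstractArray) =
    primal === shadow ? zero(shadow) : shadow

# Containers: recurse through the elements pairwise.
dealias_sketch(primal::Tuple, shadow::Tuple) = map(dealias_sketch, primal, shadow)
dealias_sketch(primal::NamedTuple{N}, shadow::NamedTuple{N}) where {N} =
    map(dealias_sketch, primal, shadow)

# Immutable leaf (e.g. a float): nothing to alias, keep the shadow.
dealias_sketch(primal, shadow) = shadow

x = [1.0, 2.0]
primal = (a = x, b = (3.0, [4.0]))
shadow = (a = x, b = (1.0, [0.5]))   # `a` is aliased, `b` is not
out = dealias_sketch(primal, shadow)
```

Dispatching on the container type sidesteps the question of what `===` means for immutable structs, at the cost of not covering user-defined structs at all.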
A slightly better version is
begin
    using Base: IdSet  # needed on Julia versions where `IdSet` is not exported

    function recursive_aliased(a::T, b::T, seen_a::IdSet{Any}, seen_b::IdSet{Any}) where {T}
        if aliased(a, b)
            return true
        end
        # Stop if either object was already visited (guards against cycles)
        if a ∈ seen_a
            return false
        elseif b ∈ seen_b
            return false
        else
            push!(seen_a, a)
            push!(seen_b, b)
        end
        # Otherwise check children, skipping uninitialized fields
        child_aliasing = ntuple(Val(fieldcount(T))) do i
            @inline
            if isdefined(a, i) && isdefined(b, i)
                _a = getfield(a, i)
                _b = getfield(b, i)
                return recursive_aliased(_a, _b, seen_a, seen_b)
            else
                return false
            end
        end
        any(child_aliasing)
    end
    function recursive_aliased(a, b)
        recursive_aliased(a, b, IdSet{Any}(), IdSet{Any}())
    end
end
but there are still corner cases that make this odd (e.g. how much can we rely on this coming from enzyme -- otherwise we must also consider mismatching types)
@wsmoses How do you define "is not leaf"? That's why I started with "===" using object equality for "mutable struct" and data-equality for "struct".
Ignoring self-recursive data structures:
using Enzyme
using ConstructionBase
function recursive_dealias(a::T, b::T) where {T}
    if aliased(a, b)
        return Enzyme.make_zero(b)
    end
    # Otherwise dealias children
    children = ntuple(Val(fieldcount(T))) do i
        @inline
        if isdefined(a, i) && isdefined(b, i)
            _a = getfield(a, i)
            _b = getfield(b, i)
            return recursive_dealias(_a, _b)
        elseif isdefined(b, i)
            return getfield(b, i)
        else
            error("Trying to dealias an uninitialized field")
        end
    end
    constructorof(T)(children...)
end
@vchuravy re on Accessors: we do have recursive traversal tools, mostly in AccessorsExtra (the more usecases there are, the more motivation to upstream :) ):
From a cursory look, not sure what exactly the target semantics of recursive_dealias is and whether Accessors are directly applicable... Should it really traverse all types no matter what they are?
There are RecursiveOfType for selecting by type and RecursivePred for arbitrary predicates. Typically, the default children traversal rule just works, but it can also be tweaked if needed.
@aplavin sorry for the ping :) We should have that conversation on Slack.
The thing I was looking for was recursion over two objects of the same type (and structure) together. As an example, a recursive addition or, as above, a recursive alias check
Hi, gentle bump on this one. Would it be possible to get a starter implementation into Enzyme(Core), even if it is overly conservative? I could print a warning in DI when the check fails, without erroring just yet
Hi, just following up here since Enzyme itself could benefit from such a function to handle correctness issues like https://github.com/EnzymeAD/Enzyme.jl/issues/2557