
Provide `is_pointerlike(x)` and `are_aliased(x, y)` to use for runtime activity

Open gdalle opened this issue 3 months ago • 14 comments

Goal: enable users to check the aliasing between primal and shadow when runtime activity is on (https://enzymead.github.io/Enzyme.jl/stable/faq/#faq-runtime-activity):

Enabling runtime activity does therefore, come with a sharp edge, which is that if the computed derivative of a function is mutable, one must also check to see if the primal and shadow represent the same pointer, and if so the true derivative of the function is actually zero.

Related:

  • https://github.com/EnzymeAD/Enzyme.jl/pull/2559#discussion_r2342976028
  • https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/666#issuecomment-3279224954
  • https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/850

gdalle avatar Sep 17 '25 07:09 gdalle

This is likely a fairly generic question, and I'd recommend you seek help on this topic in upstream Julia if @vchuravy or I are unable to help you figure out how to implement this in a timely fashion.

wsmoses avatar Sep 17 '25 18:09 wsmoses

The reason why I suggest Enzyme should implement this is because

  • as evidenced by our previous discussions, it is far from trivial to even find Julia functions that do something like that
  • the precise semantics that you need are not obvious (for example the documentation on runtime activity suggests using ===, which is apparently insufficient)
  • every Enzyme user will need this functionality to ensure that their runtime activity code is correct, not just me inside DI

gdalle avatar Sep 17 '25 19:09 gdalle

I'm in no rush to implement this, I just took a look to fix https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/666 like you asked. And the answers I found make me think that if the checks are not available inside Enzyme, no other user besides DI will ever go to the trouble of trying to figure them out.

gdalle avatar Sep 17 '25 20:09 gdalle

I mean most users probably just have a function with a single return type (e.g. an array), in which case the check is the simple one (if it is needed at all for their return type; e.g. it is not needed for a float).

The complication for your case is wanting to do this via reflection in Julia, i.e. not for a fixed type.
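For the fixed-type case described above, the check could be sketched as follows. This is a minimal sketch assuming the function returns a single `Array` of floats; `dealias_array` is a hypothetical name, not an Enzyme API, and `===` is used as the identity test even though the thread notes it may be insufficient in general.

```julia
# Hypothetical helper (not part of Enzyme): for a single Array return value,
# a shadow that is the very same object as the primal means the true
# derivative is zero, per the runtime activity FAQ.
function dealias_array(primal::Array{T,N}, shadow::Array{T,N}) where {T<:AbstractFloat,N}
    if primal === shadow      # same mutable object, i.e. aliased
        return zero(shadow)   # fresh array of zeros with the same shape
    else
        return shadow         # genuine derivative, keep it as is
    end
end

x = [1.0, 2.0, 3.0]
dx = [0.1, 0.2, 0.3]

kept = dealias_array(x, dx)   # the shadow is a distinct array, so it is returned unchanged
zeroed = dealias_array(x, x)  # the shadow aliases the primal, so a new zero array is returned
```

Returning a fresh array rather than zeroing in place avoids overwriting the primal, which is the same object as the shadow in the aliased case.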

wsmoses avatar Sep 17 '25 20:09 wsmoses

I will also note, which goes to explain why DI will have significant work to implement this: if you have, say, Tuple{Vector, Vector} as the return type, you must check each element of the tuple to see whether its primal and shadow have equal pointers (if so, that element represents zero).
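The element-wise check for a tuple of vectors could look roughly like this. It is only a sketch under the assumption that every element is an `Array` of floats; `dealias_elem` and `dealias_tuple` are hypothetical names, not Enzyme functions.

```julia
# Hypothetical helpers (not an Enzyme API): treat each tuple element
# independently, zeroing only those whose shadow aliases the primal.
dealias_elem(p::Array, s::Array) = p === s ? zero(s) : s

function dealias_tuple(primal::Tuple, shadow::Tuple)
    # `map` over tuples of unequal length would silently truncate, so check first.
    length(primal) == length(shadow) || throw(ArgumentError("tuple lengths differ"))
    return map(dealias_elem, primal, shadow)
end

a, b = [1.0, 2.0], [3.0, 4.0]
db = [0.5, 0.5]

# The first shadow element aliases the primal `a`; the second is a real derivative.
result = dealias_tuple((a, b), (a, db))
```

Here only the first element is replaced by zeros, while the second shadow is passed through untouched.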

wsmoses avatar Sep 17 '25 21:09 wsmoses

True but if we manage to find a general criterion and put it inside DI, it forces every user to be safe when activating runtime activity. That section of the docs is rather easy to miss for someone who doesn't dive deep into Enzyme

gdalle avatar Sep 17 '25 21:09 gdalle

That's why I'm looking for a systematic implementation. And even if someone doesn't want to use DI, at least we can point there and say "that's how you should handle aliasing, we figured this out for you 🎁"

gdalle avatar Sep 17 '25 21:09 gdalle

So a partial answer is something along these lines:

function aliased(a::T, b::T) where T
	a === b
end

function recursive_aliased(a::T, b::T) where {T}
	if aliased(a, b)
		return true
	end
	# Otherwise check children
	child_aliasing = ntuple(Val(fieldcount(T))) do i
		@inline
		_a = getfield(a, i)
		_b = getfield(b, i)
		recursive_aliased(_a, _b)
	end
	any(child_aliasing)
end

I was originally trying to write something with Accessors (maybe @aplavin or @jw3126 have an idea)

This does handle https://github.com/EnzymeAD/Enzyme.jl/pull/2559#discussion_r2349713345 correctly, but does not handle the "fun" case of self-recursive types or types with uninitialized fields.

vchuravy avatar Sep 18 '25 13:09 vchuravy

I'm not sure the alias function on its own is the one you want here, essentially because the "or of all aliasing" doesn't really give you actionable info (since you need the recursive alias info). I think essentially you want a dealias function that looks something like:

function dealias(primal, shadow)
   if not leaf
      recurse through elements
   else
      if mutable and aliasing
        return make_zero()
      else
        return shadow
      end
   end
end
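One possible reading of this pseudocode, restricted to nested tuples whose leaves are float arrays or float scalars, uses dispatch to separate leaves from containers. This is a sketch under those assumptions only: the `dealias` methods below are hypothetical, use `zero` from Base instead of `Enzyme.make_zero`, and do not attempt to cover arbitrary structs.

```julia
# Mutable leaf: an aliased shadow means the derivative is zero.
dealias(p::Array{<:AbstractFloat}, s::Array{<:AbstractFloat}) = p === s ? zero(s) : s

# Immutable leaf: there is no pointer to alias, so keep the shadow.
dealias(p::AbstractFloat, s::AbstractFloat) = s

# Non-leaf: recurse through the elements pairwise.
function dealias(p::Tuple, s::Tuple)
    length(p) == length(s) || throw(ArgumentError("tuple lengths differ"))
    return map(dealias, p, s)
end

x, y = [1.0], [2.0]
dy = [0.5]

# The shadow of `x` aliases its primal; `dy` and the scalar 1.0 are real derivatives.
out = dealias((x, (y, 3.0)), (x, (dy, 1.0)))
```

The result has the same nesting as the input, so the caller learns which parts were zeroed instead of getting a single boolean.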

wsmoses avatar Sep 18 '25 13:09 wsmoses

A slightly better version is

begin
	function recursive_aliased(a::T, b::T, seen_a::IdSet{Any}, seen_b::IdSet{Any}) where {T}
		if aliased(a, b)
			return true
		end
		if a ∈ seen_a
			return false
		elseif b ∈ seen_b
			return false
		else
			push!(seen_a, a)
			push!(seen_b, b)
		end
		# Otherwise check children
		child_aliasing = ntuple(Val(fieldcount(T))) do i
			@inline
			if isdefined(a, i) && isdefined(b, i)
				_a = getfield(a, i)
				_b = getfield(b, i)
				return recursive_aliased(_a, _b, seen_a, seen_b)
			else
				return false
			end
		end
		any(child_aliasing)
	end
	function recursive_aliased(a, b)
		recursive_aliased(a, b, IdSet{Any}(), IdSet{Any}())
	end
end

but there are still corner cases that make this odd (e.g. how much can we rely on the inputs coming from Enzyme? Otherwise we must also consider mismatched types)

@wsmoses How do you define "is not leaf"? That's why I started with "===" using object equality for "mutable struct" and data-equality for "struct".
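To illustrate the two behaviours of `===` mentioned here, a small Base-only example follows. The struct names are made up for illustration.

```julia
mutable struct MutableBox
    x::Float64
end

struct ImmutableBox
    x::Float64
end

struct Wrapper
    v::Vector{Float64}
end

m = MutableBox(1.0)
same_mutable = m === m                                     # true: the same object
distinct_mutable = MutableBox(1.0) === MutableBox(1.0)     # false: two separate allocations
equal_immutable = ImmutableBox(1.0) === ImmutableBox(1.0)  # true: compared by contents

v = [1.0, 2.0]
shared_field = Wrapper(v) === Wrapper(v)        # true: both wrap the same vector
copied_field = Wrapper(v) === Wrapper(copy(v))  # false: the vectors are distinct objects
```

Note that two equal immutable values, such as two `Float64`s holding the same number, also compare as `===`, which is presumably why the pseudocode above only zeroes the shadow when the leaf is mutable.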

Ignoring self-recursive data structures:

using Enzyme
using ConstructionBase

function recursive_dealias(a::T, b::T) where {T}
	if aliased(a, b)
		return Enzyme.make_zero(b)
	end
	# Otherwise dealias children
	children = ntuple(Val(fieldcount(T))) do i
		@inline
		if isdefined(a, i) && isdefined(b, i)
			_a = getfield(a, i)
			_b = getfield(b, i)
			return recursive_dealias(_a, _b)
		elseif isdefined(b, i)
			return getfield(b, i)
		else
			error("Trying to dealias an uninitialized field")
		end
	end
	constructorof(T)(children...)
end

vchuravy avatar Sep 18 '25 14:09 vchuravy

@vchuravy re on Accessors: we do have recursive traversal tools, mostly in AccessorsExtra (the more usecases there are, the more motivation to upstream :) ):

Image

From a cursory look, not sure what exactly the target semantics of recursive_dealias is and whether Accessors are directly applicable... Should it really traverse all types no matter what they are?

There are RecursiveOfType for selecting by type and RecursivePred for arbitrary predicates. Typically, the default children traversal rule just works, but it can also be tweaked if needed.

aplavin avatar Sep 19 '25 01:09 aplavin

@aplavin sorry for the ping :) We should have that conversation on Slack.

The thing I was looking for was recursion over two objects of the same type (and structure) together, for example a recursive addition or, as above, a recursive alias check.

vchuravy avatar Sep 19 '25 07:09 vchuravy

Hi, gentle bump on this one. Would it be possible to get a starter implementation into Enzyme(Core), even if it is overly conservative? I could print a warning in DI when the check fails, without erroring just yet

gdalle avatar Oct 08 '25 10:10 gdalle

Hi, just following up here since Enzyme itself could benefit from such a function to handle correctness issues like https://github.com/EnzymeAD/Enzyme.jl/issues/2557

gdalle avatar Oct 16 '25 09:10 gdalle