ChainRulesCore.jl
mutating calls
Is it possible to somehow define derivatives of in-place mutating functions? E.g. `axpy!(a, x, y)` updates `y` to `y = a*x + y`, and therefore its derivative also needs to be updated.
Apologies if this is in the documentation and I missed it.
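For reference, here is a minimal sketch of the cotangents that any rule for `axpy!` has to produce, whether or not it mutates. It treats the operation as the out-of-place map `(a, x, y) -> a*x + y` and checks the `a` cotangent with central differences. `axpy_cotangents` and `loss` are hypothetical helper names; only `axpy!` and `dot` come from the `LinearAlgebra` standard library.

```julia
using LinearAlgebra: axpy!, dot

# Cotangents of (a, x, y) -> a*x + y, given the cotangent `dy` of the result.
function axpy_cotangents(a, x, dy)
    da = dot(x, dy)    # the result depends on a through x
    dx = a .* dy       # ... on x with weight a
    dy_old = copy(dy)  # ... and on the old y with identity weight
    return da, dx, dy_old
end

# Check against the scalar loss L = dot(w, a*x + y) by central differences.
a, x, y, w = 1.5, [1.0, -2.0, 3.0], [0.5, 0.25, -1.0], [2.0, 1.0, -1.0]
loss(a) = dot(w, axpy!(a, x, copy(y)))  # copy(y) so the check never mutates y
da, dx, dy_old = axpy_cotangents(a, x, w)
h = 1e-6
fd_da = (loss(a + h) - loss(a - h)) / (2h)
```

Because `axpy!` overwrites `y`, a mutating rule additionally has to deal with the old `y` no longer being available, which is what the rules below are about.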
In theory it's fine (though there are some rules about things you have to do). In practice we don't do it because it would break Zygote: Zygote doesn't support mutation, so if we added rules that support mutation, Zygote would claim to support mutation and then error. If we get an AD that does support mutation using ChainRules, we will need to work out a way for Zygote to opt out of the mutating rules.
That's great news, though I'm not sure in what sense Zygote doesn't support this: its `Buffer` type, for example, mutates in place https://github.com/FluxML/Zygote.jl/blob/84bf62ea18330389c64d0d918c91d7b897e1a5d8/src/lib/buffer.jl
The `Buffer` type is special. It's the only thing in Zygote that supports mutation.
While I remember.
The two rules of pullbacks for mutating functions
- You must undo whatever was changed by this operation in the primal value. E.g. a `setindex!`'s pullback must set that value back to what it was before, and a `push!` must have a `pop!` in its pullback. The worst-case scenario is a full overwrite, which means you need to copy all the data before the primal computation and then write it all back during the pullback.
- The gradient for the mutation of a primal must be applied during the pullback by mutating the gradient for that thing, and it must only be applied that way. If it were also returned, it would be counted twice.
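A minimal sketch of both rules applied to a single `x[i] = v`, written as a hand-rolled rule rather than through the ChainRulesCore API. `setindex_rrule!` is a hypothetical name; the point is that the pullback restores the primal (rule 1) and updates the cotangent of the mutated array in place instead of returning it (rule 2).

```julia
# Hypothetical hand-rolled rule for `x[i] = v` following the two rules.
function setindex_rrule!(x::AbstractVector, v, i)
    old = x[i]          # save what we are about to overwrite (rule 1)
    x[i] = v
    function setindex_pullback!(dx::AbstractVector)
        x[i] = old           # rule 1: undo the primal mutation
        dv = dx[i]           # v flowed into x[i], so it receives dx[i]
        dx[i] = zero(dx[i])  # rule 2: the old x[i] was overwritten, so its
                             # cotangent is zero; mutate dx, don't return it
        return dv            # only the cotangent for v is returned
    end
    return x, setindex_pullback!
end

x = [1.0, 2.0, 3.0]
y, pb = setindex_rrule!(x, 10.0, 2)  # x is now [1.0, 10.0, 3.0]
dx = [0.1, 0.2, 0.3]                 # cotangent for the mutated x
dv = pb(dx)                          # restores x and updates dx in place
```

After the pullback runs, `x` is back to `[1.0, 2.0, 3.0]`, `dv` is `0.2`, and `dx` has become `[0.1, 0.0, 0.3]`.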
This second rule does seem to pose a problem for mutation support in functions that mutate a value and then don't return it. In that case no cotangent would be passed into the pullback for us to mutate. But such functions are rare and will (I believe) always be decomposable into a number of more primitive mutating operations that do return the mutated thing. Still, we can't write custom rules for such functions because of this.
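A small sketch of that decomposition, with hypothetical names: the non-returning function is built on top of a primitive that performs the same mutation but returns the mutated array, so a rule can be attached to the primitive (which receives the cotangent of `out`) rather than to the non-returning wrapper.

```julia
# Primitive that mutates `out` and returns it, so a rule written for it
# would receive the cotangent of `out` and could mutate it (rule 2).
square_into_ret!(out, x) = (out .= x .^ 2; out)

# The non-returning version, decomposed into that primitive: any rule would be
# written for `square_into_ret!`, not for `square_into!` itself.
square_into!(out, x) = (square_into_ret!(out, x); nothing)

out = zeros(3)
r = square_into_ret!(out, [1.0, 2.0, 3.0])  # r === out, now [1.0, 4.0, 9.0]
```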
And in Zygote, what happens when I naively define this pullback for a mutating function that does return the argument it modifies (and the underlying code always uses the return value)?
Currently I just define my rule so that it's not actually mutating in-place ...
If you do that, Zygote will sometimes silently return the wrong answer. I can't tell you offhand when that happens, though.
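One way such a mismatch between a mutating primal and a rule that assumes purity can silently go wrong is sketched below. This is a toy hand-rolled reverse pass, not Zygote, and all names are hypothetical: the naive pullback reads `x` at pullback time, when a later call has already overwritten it, while the copying variant saves the input first.

```julia
# Hypothetical mutating primitive: squares x in place and returns it.
sq!(x) = (x .= x .^ 2; x)

# Naive rule: treats sq! as pure, so the pullback reads `x` at pullback time,
# by which point later calls may have overwritten it.
function naive_rrule_sq!(x)
    y = sq!(x)
    pullback(dy) = 2 .* x .* dy  # BUG: `x` may no longer hold the input
    return y, pullback
end

# Variant that copies the input before mutating it.
function copying_rrule_sq!(x)
    x_old = copy(x)
    y = sq!(x)
    pullback(dy) = 2 .* x_old .* dy
    return y, pullback
end

# h(x) = (x^2)^2 = x^4, so dh/dx = 4x^3; at x = 2 that is 32.
function grad_h(rrule_fn, x0)
    x = copy(x0)
    y1, pb1 = rrule_fn(x)
    y2, pb2 = rrule_fn(y1)  # y1 === x, so this mutates x a second time
    return pb1(pb2(ones(length(y2))))
end

grad_h(naive_rrule_sq!, [2.0])    # wrong: both pullbacks see x == [16.0]
grad_h(copying_rrule_sq!, [2.0])  # correct: 4 * 2^3
```

Nothing errors in the naive version; it just returns 1024 instead of 32, which is the "silently wrong" failure mode.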
Here is an example of what this looks like (if an AD did support mutation), following the rules posted above:
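```julia
using ChainRulesCore

# Mutating primal: squares every element of `x` in place and returns `x`.
function f(x)
    for i in eachindex(x)
        x[i] = x[i]^2
    end
    return x
end

function ChainRulesCore.rrule(::typeof(f), x)
    # Record the signs before the forward pass overwrites them.
    x_is_negative = x .< 0
    y = f(x)  # mutates `x`, so `y === x`
    function pullback(dy)
        # Rule 1: undo the change to `x` in case it is used in another rule.
        x .= sqrt.(x) .* ifelse.(x_is_negative, -1, 1)
        # Rule 2: `x` was mutated on the forward pass, so store its derivative
        # by mutating `dy` in place: d(x^2)/dx = 2x.
        dy .*= 2 .* x
        # Return ZeroTangent(), not `dy`: the gradient has already been
        # accumulated by mutating `dy`, and returning it would count it twice.
        return NoTangent(), ZeroTangent()
    end
    return y, pullback
end
```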
So in Enzyme we support mutating calls, aliasing (cf. #350) and activity (cf. #452). All of these problems are closely related.
For me, a motivating example is supporting GPU code where outputs are mutated and there is no return value; I want to accumulate the gradients in place (since memory pressure on the GPU is a huge issue).
One of the issues @wsmoses and I have been debating is the responsibility of the caller, since Enzyme doesn't use closures to capture the inputs but expects the user to pass in both the shadow and the primal value. So if another part of the program mutates an input, I think we currently expect the user to cache it.
Aliasing within the adjoint is solved by caching, but since we can use LLVM alias analysis, we can limit how much we need to cache.
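To make the shadow-and-primal convention concrete, here is a hand-written sketch of a GPU-style kernel that overwrites its output and returns nothing, plus a matching reverse pass that accumulates into a caller-provided shadow in place. This is an illustration of the idea only, not Enzyme's API or the code it generates; `square_kernel!` and `square_kernel_reverse!` are hypothetical names, and the reverse pass assumes the caller has kept `x` unchanged since the forward pass.

```julia
# Primal: overwrites `out` with x.^2 and returns nothing.
function square_kernel!(out, x)
    @inbounds for i in eachindex(out, x)
        out[i] = x[i]^2
    end
    return nothing
end

# Reverse pass with caller-provided shadows: `dout` holds the cotangent of
# `out`, and `dx` is accumulated into in place (no new arrays allocated).
# Assumes `x` still holds its forward-pass values (the caller's responsibility).
function square_kernel_reverse!(dout, dx, x)
    @inbounds for i in eachindex(dout, dx, x)
        dx[i] += 2 * x[i] * dout[i]   # accumulate rather than overwrite
        dout[i] = zero(eltype(dout))  # out's old value was overwritten
    end
    return nothing
end

x = [1.0, -2.0, 3.0]; dx = zeros(3)
out = similar(x); dout = ones(3)
square_kernel!(out, x)
square_kernel_reverse!(dout, dx, x)
```

Since `dx` is accumulated with `+=`, several reverse passes can add into the same shadow buffer without allocating intermediate gradient arrays.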