ChainRulesCore.jl

Aliased differentials and Inplace Accumulation

Open oxinabox opened this issue 4 years ago • 2 comments

If a pullback is something like dx->(dx, dx) with dx a reference type, e.g. the pullback for +(::Array, ::Array), then the two outputs are the same object: they are aliased, even though they are differentials for (usually) distinct, non-aliased primal values. This causes in-place accumulation to give wrong results, because accumulating into one differential silently changes the other.
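To make the failure concrete, here is a minimal sketch in plain Julia (no AD package involved). The names `plus_pullback` and `other` are made up for illustration, and `.+=` stands in for whatever in-place accumulation an AD system would perform.

```julia
# A pullback shaped like the one for +(::Array, ::Array):
# both returned differentials are the very same array.
plus_pullback(dx) = (dx, dx)

dy = [1.0, 1.0, 1.0]
da, db = plus_pullback(dy)
aliased = da === db          # true: one object, two names

# Suppose the first argument receives another gradient contribution,
# and the AD system accumulates it in place.
other = [10.0, 20.0, 30.0]
da .+= other

# The differential for the second argument has changed as well,
# although nothing was ever accumulated into it.
corrupted = db == [11.0, 21.0, 31.0]

# Copying one of the outputs avoids this, at the cost of an allocation
# that is wasted whenever no in-place accumulation happens.
safe_plus_pullback(dx) = (dx, copy(dx))
dy2 = [1.0, 1.0, 1.0]
da2, db2 = safe_plus_pullback(dy2)
da2 .+= other
```

With the copying variant, `db2` keeps its original values after `da2` is updated.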

See discussion here https://github.com/FluxML/Zygote.jl/pull/962#issuecomment-835884201 where @mcabbott was just adding inplace accumulation for getindex.

I am wondering if we need to require that when the primals are not aliased, the differentials must not be aliased either. The extra copies are slow if in-place accumulation is not actually done, though, so we might want to make it configurable depending on the AD system. (This is related to the config needs for #68.)

I think the ideal solution would be copy-on-write objects. But doing that without language support is suffering.

oxinabox, May 11 '21 15:05

@willtebbutt and @mcabbott have both proposed something that looks somewhat like a thunk, which delays the copy of the aliased differential until it is used (e.g. in a call to add!!). Potentially this could even be @thunk(copy(dx)), I guess? Though maybe we only want that when in-placing, in which case it might be possible to do it with an InplaceableThunk? Not sure. More thinking is needed.

The other option is something that gathers up a long chain of all attempts to accumulate against it and then, at the last minute when it has no choice, calls copy and applies them all at once. That is still thunk-like, and might be implemented as a wrapper around nested InplaceableThunks?
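A rough sketch of that "gather, then copy once" idea, in plain Julia. Everything here (`PendingAccum`, `queue_add!`, `materialize`) is a hypothetical name invented for illustration; it is not ChainRulesCore API and does not use InplaceableThunk.

```julia
# Hypothetical wrapper: holds a possibly shared array plus a queue of
# contributions that have not been applied yet.
struct PendingAccum{T<:AbstractArray}
    base::T                 # may be aliased with other differentials
    pending::Vector{Any}    # contributions waiting to be added
end
PendingAccum(x::AbstractArray) = PendingAccum(x, Any[])

# Record a contribution without touching the shared array.
queue_add!(p::PendingAccum, delta) = (push!(p.pending, delta); p)

# Only when a concrete value is needed: copy once, then apply everything.
function materialize(p::PendingAccum)
    out = copy(p.base)
    for d in p.pending
        out .+= d
    end
    return out
end

dy = [1.0, 1.0]
pa, pb = PendingAccum(dy), PendingAccum(dy)   # both wrap the same array
queue_add!(pa, [1.0, 2.0])
queue_add!(pa, [3.0, 4.0])
ra = materialize(pa)    # one copy, two additions
rb = materialize(pb)    # unaffected by the updates queued on pa
```

The queue itself holds references to every pending contribution, which is a memory cost this sketch does nothing to reduce.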

oxinabox, May 11 '21 18:05

Yes, I think it's not quite @thunk(copy(dy)), since you want to be able to read the array freely; only when writing must you be careful.

But it's similar in that I think it could be handled by the code interfacing to rrules. If the rule for + is something like dy -> (Xerox(dy),Xerox(dy)), most rules could just receive the un-wrapped dy.original. Only those which return the same array dy -> (dy,) would need to leave it wrapped -- or perhaps the interface function tests what they return, and re-wraps if necessary.

The simplest version would just copy on the first attempt to write. My guess is that would get you most of the benefit. The repeated scalar indexing cases of https://github.com/FluxML/Zygote.jl/pull/962 still wouldn't involve any extra copies, as they didn't call the adjoint of +, they just need to know it's safe.
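As a sketch of that simplest version, under stated assumptions: `Xerox` and its field `original` follow the names used above, while the `owned` flag and the `write_add!` function are hypothetical additions for illustration. None of this is existing ChainRulesCore or Zygote API.

```julia
# Hypothetical copy-on-first-write wrapper.
# Assumes copy(x.original) has the same type as x.original
# (true for Array; not necessarily for views or other wrappers).
mutable struct Xerox{T<:AbstractArray}
    original::T
    owned::Bool    # true once this wrapper holds a private copy
end
Xerox(x::AbstractArray) = Xerox(x, false)

# Reading is free: no copy is ever made for a read.
Base.getindex(x::Xerox, i...) = getindex(x.original, i...)

# Writing copies on the first attempt only, then mutates the private copy.
function write_add!(x::Xerox, delta)
    if !x.owned
        x.original = copy(x.original)
        x.owned = true
    end
    x.original .+= delta
    return x
end

dy = [1.0, 1.0, 1.0]
da, db = Xerox(dy), Xerox(dy)       # what a rule for + might return
write_add!(da, [10.0, 20.0, 30.0])  # first write: copies, then adds
write_add!(da, [1.0, 1.0, 1.0])     # second write: no further copy
first_of_db = db[1]                 # reads still see the untouched data
```

Here `db` never copies at all, since it is only read; the cost is one allocation for `da`, paid at the first write rather than when the rule returns.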

Delaying this as you suggest might mean that all other views of the same data have been discarded by the time you want to write. If the un-wrapping (and re-wrapping if ===) is all handled by one rrule interface function, then perhaps it would not be too hard to track the number of views in existence. But still some complexity. And holding onto the stack of things you intend to write there may also have a cost.

mcabbott, May 11 '21 19:05