
Handling rules that construct pullbacks via mutations of copies of their inputs

Open Roger-luo opened this issue 4 years ago • 3 comments

OK, I admit I can't find a better title for this, but I feel it should be brought up. See the related issues #381, #380, and #46.

@GiggleLiu keeps wanting the return type of every pullback to be the same as its input type; that's why he was asking how to fix the problem in #381 for all possible cases. However, I don't think this would work with the current thunk mechanism.

I'm thinking there should be a mutability trait to handle this, together with the BangBang version of setindex! (i.e. setindex!!), so that when a rule writer requires this interface, ChainRules can check for it and error otherwise.
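A minimal sketch of what I mean, using BangBang.jl; the trait and helper names here are hypothetical, not existing ChainRules API:

```julia
using BangBang: setindex!!  # mutates when possible, otherwise returns an updated copy

# Hypothetical trait (not existing ChainRules API): a rule that wants to build
# a cotangent by mutating a copy of its input could query this up front,
# and ChainRules could error when the container doesn't support it.
is_mutable_container(::Type{<:Array}) = true
is_mutable_container(::Type) = false

function update_cotangent!!(dx, v, i)
    # setindex!!(xs, v, i) writes in place for mutable containers like Array;
    # for immutable containers it falls back to returning an updated copy.
    return setindex!!(dx, v, i)
end

dx = zeros(3)
dx = update_cotangent!!(dx, 1.0, 2)  # mutated in place, still a Vector{Float64}
```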

Meanwhile, I think part of the reason we need thunks is that current AD compilers like Zygote cannot do the corresponding optimizations themselves. If the compiler could do what thunks do — e.g. Yota, I believe, can figure out gradient accumulation and unused gradients — then there would be no need to be lazy.
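For context, this is roughly how a rule uses thunks today, following the documented ChainRulesCore rrule/@thunk pattern (on a made-up `mymul` rather than `Base.*`, to keep the sketch self-contained):

```julia
using ChainRulesCore

mymul(A, B) = A * B

function ChainRulesCore.rrule(::typeof(mymul), A::Matrix, B::Matrix)
    Y = mymul(A, B)
    function mymul_pullback(ΔY)
        # Each cotangent is wrapped lazily: the matrix product only runs if
        # the caller unthunks it. An AD compiler that sees the whole program
        # could instead prove that, say, dA is unused and delete the ΔY * B'
        # product at compile time, making the laziness unnecessary.
        dA = @thunk(ΔY * B')
        dB = @thunk(A' * ΔY)
        return (NoTangent(), dA, dB)
    end
    return Y, mymul_pullback
end
```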

cc: @oxinabox @sethaxen

Roger-luo avatar Feb 20 '21 00:02 Roger-luo

AFAICT this has nothing to do with the thunking mechanism.

oxinabox avatar Feb 20 '21 00:02 oxinabox

> AFAICT this has nothing to do with the thunking mechanism.

I'm just saying that one cannot make the input and output types of a pullback the same; in that case, creating thunks would be disallowed, which is what @GiggleLiu proposed.
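To make that concrete: a `Thunk` is its own wrapper type, so a thunked cotangent can never have the same type as the primal it corresponds to.

```julia
using ChainRulesCore

dA = @thunk(2 * ones(2, 2))
typeof(dA) <: Matrix    # false: it is a ChainRulesCore.Thunk, not a Matrix
unthunk(dA) isa Matrix  # true, but only after forcing the computation
```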

Roger-luo avatar Feb 20 '21 01:02 Roger-luo

@Roger-luo thanks for opening an issue; this is an important topic that deserves extensive discussion. I wish someone could convince me why the adjoint type cannot always be the same as the output type (or Nothing). I am unimpressed by thunks, because PyTorch gets very good performance using a single tensor type. To me, it seems "returning the same type or Nothing" can do most of the job, and it would make debugging so much easier.

Let me add a use case: when doing sparse matrix-vector multiplication, many Julia AD packages will give me a dense array as the adjoint and blow up the memory. This is a case where I'd rather it error.
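A sketch of that failure mode, using Zygote (the exact cotangent type depends on the AD stack and rule versions in use):

```julia
using SparseArrays, Zygote

A = sprand(1000, 1000, 0.001)           # sparse primal, ~1000 stored entries
x = rand(1000)

y, back = Zygote.pullback(A -> A * x, A)
(dA,) = back(ones(1000))

# Mathematically dA = ȳ * x', a rank-one outer product. Without a rule that
# projects onto A's sparsity pattern, it materializes as a dense 1000×1000
# Matrix{Float64}: ~8 MB of adjoint for a primal that stores a few kilobytes.
```

(ChainRulesCore's later `ProjectTo` mechanism addresses exactly this by projecting cotangents onto the primal's structure, e.g. a sparse matrix's sparsity pattern.)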

GiggleLiu avatar Feb 20 '21 02:02 GiggleLiu