pytorch doesn't clear out gradients between `backward()` calls in the gradient collector
I've been following along with the linear regression code in the Dive into Deep Learning book. When using the PyTorch engine, the gradients are not cleared between backpropagation runs.
For example (this code is Scala, but that makes no difference):
// mn is an ai.djl.ndarray.NDManager (e.g. NDManager.newBaseManager())
val x = mn.arange(4.0f)
x.setRequiresGradient(true)
println(s"x=$x")
for (i <- 0 until 4) {
  Using.resource(Engine.getInstance().newGradientCollector()) { gc =>
    val y = x.dot(x) * 2
    gc.backward(y)
  }
  println(s"gradient=${x.getGradient()}")
}
This code outputs:
x=ND: (4) cpu() float32 hasGradient
[0., 1., 2., 3.]
gradient=ND: (4) cpu() float32
[ 0., 4., 8., 12.]
gradient=ND: (4) cpu() float32
[ 0., 8., 16., 24.]
gradient=ND: (4) cpu() float32
[ 0., 12., 24., 36.]
gradient=ND: (4) cpu() float32
[ 0., 16., 32., 48.]
i.e. the gradients keep getting summed between runs when using the PyTorch engine, but I get the correct output (below) when swapping the PyTorch dependency out for the MXNet one (identical code):
x=ND: (4) cpu() float32 hasGradient
[0., 1., 2., 3.]
gradient=ND: (4) cpu() float32
[ 0., 4., 8., 12.]
gradient=ND: (4) cpu() float32
[ 0., 4., 8., 12.]
gradient=ND: (4) cpu() float32
[ 0., 4., 8., 12.]
gradient=ND: (4) cpu() float32
[ 0., 4., 8., 12.]
This is under macOS, both x86 and ARM (CPU).
This is the expected behavior for PyTorch: by default, if a PyTorch tensor (or in this case a DJL NDArray backed by the PyTorch engine) enables `setRequiresGradient`, the gradients are accumulated across calls to `backward()`. This behavior is documented in the PyTorch documentation, so multiple calls to backward will produce the results you observed here with PyTorch.
In PyTorch you would call `optimizer.zero_grad()` with the optimizer you are using in the training loop to reset the gradients back to 0. In DJL, we don't expose this method directly to users; instead we call it as part of the optimizer updates. So, for the example you have provided here, I don't think there is a way to zero out the gradients without also using a Trainer and a built-in Optimizer (see this example of Training MNIST for more details).
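For reference, here is a minimal sketch of what that looks like with a Trainer and a built-in Optimizer. It assumes an existing `model` of type `ai.djl.Model`, a `dataset`, and `batchSize`/`numFeatures` values; it mirrors the flow of the MNIST example mentioned above rather than quoting it:

import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.loss.Loss
import scala.util.Using

val config = new DefaultTrainingConfig(Loss.l2Loss())
Using.resource(model.newTrainer(config)) { trainer =>
  trainer.initialize(new Shape(batchSize, numFeatures))
  trainer.iterateDataset(dataset).forEach { batch =>
    Using.resource(trainer.newGradientCollector()) { gc =>
      val pred = trainer.forward(batch.getData).singletonOrThrow()
      val loss = trainer.getLoss.evaluate(batch.getLabels, new NDList(pred))
      gc.backward(loss)
    }
    trainer.step() // optimizer update; gradient bookkeeping happens here, not in user code
    batch.close()
  }
}

The point of the sketch is the last part of the loop: the user never touches the gradients directly, and the reset is tied to the optimizer update driven by `trainer.step()`.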
Are there specific code blocks in the D2L book you were using where this behavior is causing confusion? I can take a look and see what we need to update.
I'm not sure whether it makes sense for us to expose the `zero_grad()` method, since it seems like PyTorch-specific behavior and we try to keep the API engine-agnostic, but I can take a look at our code and reevaluate.
Thanks for your reply. I am new to the project and naively expected that the different backends would function more or less the same with respect to usage. I understand that this is probably the case when using a Trainer etc.
I was going through the examples in section 2.5 (Automatic Differentiation), which are not run in the context of a Trainer, so I was confused by the varying results. When I switched to MXNet the problem went away, which I found surprising. (Obviously I'd prefer to use an ARM-native JVM (without Rosetta), which the PyTorch backend allows but MXNet does not.)
I think putting a note in the chapter would be sufficient if you don't want to zero out the gradients in the collector.
BTW: when I didn't find an accessible `zero_grad`, I found that I could zero the gradients via:
x.getGradient().subi(x.getGradient())
which was sufficient for testing purposes.
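For completeness, dropping that workaround into the original loop might look like the sketch below (a sketch only; it uses `mul(2)` in place of the `*` operator in case no implicit operator extensions are in scope):

for (i <- 0 until 4) {
  Using.resource(Engine.getInstance().newGradientCollector()) { gc =>
    val y = x.dot(x).mul(2)
    gc.backward(y)
  }
  println(s"gradient=${x.getGradient()}")
  x.getGradient().subi(x.getGradient()) // reset the accumulated gradient to zero in place
}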
I'm glad you found a way to get around the issue - that solution certainly works @ajrnz.
I'll take a look at the D2L book and see where a good place for this disclaimer is. There are probably some other places within our Jupyter notebooks and examples where similar notes are needed, so I'll check those as well.
I'm closing this issue, but feel free to reopen it if you come across any other instances where this behavioral difference causes confusion.
Can we simply add a `zero_grad()` step in the PyTorch implementation of the backward call, @siddvenk? This solution assumes that repeatedly calling `zero_grad()` is OK, considering that in Trainer there is also a `zero_grad()` step.
@KexinFeng we could, but I don't think we should. If someone is using DJL + PyTorch for distributed training, they may choose to accumulate gradients over some microbatches before updating the weights. I think we should keep backward the same as it's implemented in PyTorch now, to avoid confusing users.
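To illustrate the accumulation pattern referenced here, a rough sketch (the `microBatches` collection, the `computeLoss` helper, and the `trainer` are all assumed for illustration):

Using.resource(Engine.getInstance().newGradientCollector()) { gc =>
  for (microBatch <- microBatches) {   // assumed collection of microbatch inputs/labels
    val loss = computeLoss(microBatch) // assumed helper returning an NDArray loss
    gc.backward(loss)                  // gradients sum across microbatches (PyTorch semantics)
  }
}
trainer.step() // a single update then consumes the accumulated gradients

Zeroing the gradients inside `backward()` itself would silently break this pattern, which is why keeping the PyTorch semantics is preferable.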