torchgpipe
convergence problem
checkpoint = (j < checkpoint_stop)
if checkpoint:
    chk = Checkpointing(partition, batch)
    task = Task(streams[i], compute=chk.checkpoint, finalize=chk.recompute)
    del chk
else:
    def compute(batch: Batch = batch,
                partition: nn.Sequential = partition) -> Batch:
        return batch.call(partition)

    task = Task(streams[i], compute=compute, finalize=None)
    del compute
When I used the second method (plain compute instead of checkpoint), I found that my network's performance became worse, and the degradation was proportional to the number of divisions.
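For reference, the branch above is gated by checkpoint_stop, which GPipe derives from its checkpoint argument. The snippet below is only a minimal sketch of that assumed mapping (not the library's exact code, and checkpoint_stop_for is a made-up helper name); with checkpoint='never', every micro-batch takes the plain compute path.

def checkpoint_stop_for(mode: str, chunks: int) -> int:
    # 'always': checkpoint every micro-batch.
    # 'except_last': checkpoint all micro-batches but the last one.
    # 'never': checkpoint nothing, so every micro-batch uses plain compute.
    return {'always': chunks, 'except_last': chunks - 1, 'never': 0}[mode]

for mode in ('always', 'except_last', 'never'):
    print(mode, '->', checkpoint_stop_for(mode, chunks=4))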
Hi @yangpc615, thanks for the report.
I have a few questions to understand the case.
- Does your network highly depend on BatchNorm or any algorithm regarding the batch dimension?
- Is there no convergence problem if you enable checkpointing?
- Can you share any other information about your network?
Thanks for your reply. Do you know mmdetection?
- I tried to apply torchgpipe to the HybridTaskCascade class of mmdetection.
- After each convolution layer there is a BatchNorm.
- Currently my network doesn't work with checkpointing.
And I want to know how to update the network with the plain compute method instead of checkpointing in torchgpipe.
@yangpc615 Did you mean that your network doesn't converge both with and without checkpointing?
Anyway, if the network relies heavily on BatchNorm, a large number of micro-batches may affect training, just like DataParallel does. See the documented trade-off in the number of micro-batches. There's an option for this case in GPipe; see "Deferred Batch Normalization" for more details.
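As a minimal sketch of how these options fit together (assuming the standard GPipe constructor, at least two visible GPUs, and a placeholder model and balance):

import torch.nn as nn
from torchgpipe import GPipe

# Toy model standing in for the real network; each block ends with BatchNorm.
model = nn.Sequential(
    nn.Linear(16, 16), nn.BatchNorm1d(16), nn.ReLU(),
    nn.Linear(16, 16), nn.BatchNorm1d(16), nn.ReLU(),
)

model = GPipe(
    model,
    balance=[3, 3],            # placeholder split into two partitions
    chunks=4,                  # number of micro-batches
    deferred_batch_norm=True,  # accumulate BatchNorm statistics over the
                               # whole mini-batch instead of per micro-batch
)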
Thank you. In addition I don't understand the following code:
def depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)
What do these functions do, and how do they relate to the following code:
def recompute(self, batch: Batch) -> None:
    """Applies :class:`Recompute` to the batch in place."""
    input_atomic = self.batch.atomic
    input = tuple(self.batch)

    # batch[0] is always requiring grad, because it has been passed
    # checkpoint with a phony requiring grad.
    batch[0], phony = fork(batch[0])
    phony = Recompute.apply(phony, self.recomputed, self.rng_states,
                            self.function, input_atomic, *input)
    batch[0] = join(batch[0], phony)
@yangpc615 That is a good question. However, I recommend making a separate issue for a new question not related to the convergence problem.
fork and join make an arbitrary dependency on an autograd graph through an empty tensor called phony. It forces the autograd engine to follow our desired execution order. Recompute should be executed at a specific moment in the backward pass, but it is not related to the actual gradient flow. Here comes phony, which is an empty tensor with size 0. We use it to avoid unnecessary gradient accumulation.
           +-----------------------------+
           |                             |
... ---> fork - - - > Recompute - - - > join ---> ...
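Below is a self-contained sketch of that idea in plain PyTorch rather than torchgpipe's internal API (SideWork and Upstream are made-up stand-ins for Recompute and Checkpoint): a zero-sized phony tensor threaded through fork/join-style edges forces one node's backward to run before another's, without changing any values or gradients.

import torch

class SideWork(torch.autograd.Function):
    """Stands in for Recompute: identity in forward, side work in backward."""
    @staticmethod
    def forward(ctx, phony):
        return phony.clone()

    @staticmethod
    def backward(ctx, grad):
        print('1. side work (like recomputation) runs first in backward')
        return grad

class Upstream(torch.autograd.Function):
    """Stands in for Checkpoint: its backward must wait for SideWork's."""
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        print('2. upstream backward runs only after the side work above')
        return grad

x = torch.randn(3, requires_grad=True)
y = Upstream.apply(x * 2)

# "fork": a zero-sized phony tensor that still shares y's autograd history.
phony = SideWork.apply(y[:0])

# "join": make the continuing tensor depend on phony, so autograd cannot
# finish y's gradient (and reach Upstream.backward) before SideWork.backward.
z = torch.cat([y, phony])  # phony is empty, so values and gradients are unchanged

z.sum().backward()
print(x.grad)  # tensor([2., 2., 2.]) -- the phony edge did not alter gradients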