
convergence problem

Open · yangpc615 opened this issue on Dec 18 '19 · 6 comments

        # Checkpoint only the first `checkpoint_stop` micro-batches.
        checkpoint = (j < checkpoint_stop)
        if checkpoint:
            # Checkpointed path: run the forward pass under checkpointing and
            # register a recomputation to be replayed during the backward pass.
            chk = Checkpointing(partition, batch)
            task = Task(streams[i], compute=chk.checkpoint, finalize=chk.recompute)
            del chk

        else:
            # Plain path: call the partition directly; nothing to finalize.
            def compute(batch: Batch = batch, partition: nn.Sequential = partition) -> Batch:
                return batch.call(partition)
            task = Task(streams[i], compute=compute, finalize=None)
            del compute

When I used the second path (plain compute, without checkpointing), I found that my network's results got worse, and the degradation is proportional to the number of divisions.
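For reference, the checkpoint_stop in the snippet above comes from GPipe's checkpoint option. A minimal sketch of the mapping, following the torchgpipe source (the values here are illustrative):

# How GPipe derives `checkpoint_stop` from its `checkpoint` option.
chunks = 8                  # number of micro-batches (illustrative)
checkpoint = 'except_last'  # one of 'always' | 'except_last' | 'never'

checkpoint_stop = {
    'always': chunks,           # checkpoint every micro-batch
    'except_last': chunks - 1,  # skip checkpointing for the last one
    'never': 0,                 # plain compute path for all micro-batches
}[checkpoint]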

yangpc615 · Dec 18 '19 06:12

Hi @yangpc615, thanks for the report.

I have a few questions to understand the case.

  • Does your network depend heavily on BatchNorm or on any other computation across the batch dimension?
  • Does the convergence problem disappear when you enable checkpointing?
  • Can you share any other details about your network?

sublee · Dec 18 '19 07:12

Thanks for your reply. Do you know mmdetection?

  • I tried to apply torchgpipe to the HybridTaskCascade class of mmdetection.
  • After each convolution layer there is a BatchNorm.
  • Right now my network doesn't converge with the checkpointing path either.

yangpc615 · Dec 18 '19 08:12

I also want to know how to train the network through the plain compute path, without checkpointing, in torchgpipe.

yangpc615 · Dec 18 '19 08:12

@yangpc615 Did you mean that your network doesn't converge both with and without checkpointing?

Anyway, if the network relies heavily on BatchNorm, a large number of micro-batches may affect training just like DataParallel does: each micro-batch computes its batch statistics over only a fraction of the mini-batch. See the documentation on the trade-off in the number of micro-batches. There's an option in GPipe for this case; see "Deferred Batch Normalization" for more details.
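For reference, here is a minimal sketch of turning that option on. The toy model, balance, and chunks below are illustrative placeholders, and the two partitions assume two available devices:

import torch.nn as nn
from torchgpipe import GPipe

# Toy model; all values are placeholders, not a recommendation.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
)

model = GPipe(
    model,
    balance=[3, 3],            # layers per partition (assumes 2 devices)
    chunks=8,                  # number of micro-batches
    checkpoint='except_last',  # 'always' | 'except_last' | 'never'
    deferred_batch_norm=True,  # track BN statistics over the whole mini-batch
)

With checkpoint='never', every micro-batch takes the plain compute path quoted at the top of this issue.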

sublee · Dec 19 '19 08:12

Thank you. In addition, I don't understand the following code:

def depend(fork_from: Batch, join_to: Batch) -> None:
    # Make `join_to` depend on `fork_from` in the autograd graph.
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)

What do these functions do, and how do they relate to the following code?

    def recompute(self, batch: Batch) -> None:
        """Applies :class:`Recompute` to the batch in place."""
        input_atomic = self.batch.atomic
        input = tuple(self.batch)

        # batch[0] is always requiring grad, because it has been passed
        # checkpoint with a phony requiring grad.
        batch[0], phony = fork(batch[0])
        phony = Recompute.apply(phony, self.recomputed, self.rng_states,
                                self.function, input_atomic, *input)
        batch[0] = join(batch[0], phony) 

yangpc615 · Dec 24 '19 07:12

@yangpc615 That is a good question. However, I recommend opening a separate issue for a new question that is not related to the convergence problem.

fork and join make an artificial dependency in the autograd graph through an empty tensor called a phony. This forces the autograd engine to follow our desired execution order: Recompute has to run at a specific moment in the backward pass, but it is not part of the actual gradient flow. That is where the phony comes in. Since it is an empty tensor of size 0, it encodes the ordering without accumulating any unnecessary gradient.

           +-----------------------------+
           |                             |
... ---> fork - - - > Recompute - - - > join ---> ...
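To make this concrete, below is a minimal standalone sketch of the fork/join idea. It simplifies torchgpipe's actual implementation, but the structure is the same: the phony creates a graph edge without contributing to the gradients.

import torch
from torch import Tensor
from typing import Tuple

class Fork(torch.autograd.Function):
    """Return the input unchanged plus an empty 'phony' tensor."""
    @staticmethod
    def forward(ctx, x: Tensor) -> Tuple[Tensor, Tensor]:
        phony = torch.empty(0, device=x.device)
        return x.detach(), phony

    @staticmethod
    def backward(ctx, grad_x: Tensor, grad_phony: Tensor) -> Tensor:
        # The phony carries no gradient of its own; pass x's gradient through.
        return grad_x

class Join(torch.autograd.Function):
    """Return the input unchanged, but record a dependency on the phony."""
    @staticmethod
    def forward(ctx, y: Tensor, phony: Tensor) -> Tensor:
        return y.detach()

    @staticmethod
    def backward(ctx, grad_y: Tensor):
        # No gradient flows into the phony branch; only the graph edge matters.
        return grad_y, None

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)

out_x, phony = Fork.apply(x * 2)
out_y = Join.apply(y * 3, phony)  # out_y now depends on x in the graph

# During backward, autograd must run Join's backward (and anything hanging
# off the phony edge, such as Recompute) before it can run Fork's backward.
(out_x.sum() + out_y.sum()).backward()

In the real library, Recompute.apply is attached to exactly this phony edge, as in the recompute method quoted above, which is why the recomputation is triggered at the right moment in the backward pass.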

sublee · Dec 25 '19 14:12