
gradient checkpointing incompatible with mixed precision?

Open XavierXiao opened this issue 1 year ago • 1 comments

Hi, I found that when using gradient checkpointing (i.e., setting use_checkpoint: True), I cannot run mixed-precision training (enabled by adding the precision=16 argument to the Trainer). Disabling gradient checkpointing resolves the issue. The error log is:

  File "main.py", line 861, in <module>
    trainer.fit(model, data)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
    self.fit_loop.run()
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
    result = self._run_optimization(
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
    lightning_module.optimizer_step(
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1652, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
    self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 85, in optimizer_step
    closure_result = closure()
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 155, in closure
    self._backward_fn(step_output.closure_loss)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 327, in backward_fn
    self.trainer.accelerator.backward(loss, optimizer, opt_idx)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 311, in backward
    self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 91, in backward
    model.backward(closure_loss, optimizer, *args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1434, in backward
    loss.backward(*args, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/home/yanq/Codes/stable-diffusion/ldm/modules/diffusionmodules/util.py", line 138, in backward
    output_tensors = ctx.run_function(*shallow_copies)
  File "/home/yanq/Codes/stable-diffusion/ldm/modules/attention.py", line 212, in _forward
    x = self.attn1(self.norm1(x)) + x
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py", line 2347, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Half but found Float
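
For context (my reading of the traceback, not something stated above): CheckpointFunction re-runs the checkpointed block inside backward(), where autocast is no longer active, so the fp16 activations saved from the autocast forward meet the fp32 LayerNorm parameters directly. A minimal standalone sketch of that mismatch, with made-up tensor shapes:

import torch
import torch.nn.functional as F

ln = torch.nn.LayerNorm(64).cuda()                  # LayerNorm parameters stay fp32 under AMP
x_half = torch.randn(2, 64, device="cuda").half()   # stand-in for an fp16 activation from the autocast forward

with torch.cuda.amp.autocast():
    F.layer_norm(x_half, (64,), ln.weight, ln.bias, ln.eps)   # fine: autocast reconciles the dtypes

# CheckpointFunction.backward recomputes outside any autocast region, where the
# same call fails on the PyTorch version shown in the traceback:
F.layer_norm(x_half, (64,), ln.weight, ln.bias, ln.eps)
# RuntimeError: expected scalar type Half but found Float

(Newer PyTorch releases may accept the mixed dtypes here, but on the version above this reproduces the error.)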

XavierXiao · Sep 15 '22

I'm facing exactly the same problem. I see that you fixed the error by disabling gradient checkpointing (https://github.com/rinongal/textual_inversion/blob/0a950b482d2e8f215122805d4c5901bdb4a6947f/ldm/modules/diffusionmodules/util.py#L112).

Cool! Hope we find a way to make it compatible though.

joanrod · Oct 08 '22

There are two ways to use AMP here; I prefer option 2 in my case.

  1. Add the custom_fwd and custom_bwd decorators so gradient checkpointing works with AMP. This gives roughly a 1.2x speedup (A5000) at the same memory cost (see the sketch after this list for what these decorators do).

ldm/modules/diffusionmodules/util.py line 119

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) # add this
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.cuda.amp.custom_bwd # add this
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
  2. Fully disable gradient checkpointing and use AMP only (as in rinongal/textual_inversion). This gives roughly a 1.4x speedup at about 1.5x the memory cost.

ldm/modules/diffusionmodules/util.py line 112

def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if False:  # originally "if flag:"; hard-coded to False to disable gradient checkpointing
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
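
For reference, a standalone sketch (not from ldm) of what the two decorators in method 1 do: under autocast, custom_fwd(cast_inputs=torch.float32) casts incoming floating-point CUDA tensors to fp32 and runs forward with autocast disabled, and custom_bwd runs backward with the same autocast state as forward. The checkpointed blocks therefore execute in fp32, which is consistent with method 1's smaller speedup compared to method 2.

import torch

class DtypeProbe(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        # Even when called under autocast with an fp16 input, x arrives here as
        # fp32 and autocast is disabled for the body of this function.
        print(x.dtype, torch.is_autocast_enabled())   # torch.float32 False
        return x * 2

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        # custom_bwd replays the autocast state that was active in forward.
        return grad_output * 2

x = torch.randn(8, device="cuda", requires_grad=True)
with torch.cuda.amp.autocast():
    y = DtypeProbe.apply(x.half())   # x.half() stands in for an fp16 activation
print(y.dtype)                       # torch.float32: the block's output comes back in fp32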

hepesu · Oct 28 '22
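
A possible third route that nobody tried in this thread (sketch only, assuming the installed torch.utils.checkpoint records and restores the autocast state when it recomputes the block, which recent PyTorch versions do as far as I know): replace the custom CheckpointFunction with torch.utils.checkpoint, keeping both the checkpointing memory savings and the fp16 compute.

import torch.utils.checkpoint

def checkpoint(func, inputs, params, flag):
    """
    Hypothetical drop-in for ldm/modules/diffusionmodules/util.py: delegate to
    torch.utils.checkpoint, which re-enables the autocast state of the original
    forward when it recomputes func during the backward pass.
    """
    if flag:
        # Parameter gradients still flow because func closes over the module;
        # torch.utils.checkpoint does not need the explicit params argument.
        return torch.utils.checkpoint.checkpoint(func, *inputs)
    return func(*inputs)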