stable-diffusion
gradient checkpointing incompatible with mixed precision?
Hi, I found that when using gradient checkpointing (i.e., setting use_checkpoint: True), I cannot run mixed-precision training (enabled by adding the precision=16 argument to the Trainer). Disabling gradient checkpointing makes the error go away. The error log is:
File "main.py", line 861, in <module>
trainer.fit(model, data)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
self._call_and_handle_interrupt(
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
self._dispatch()
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
return self._run_train()
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
self.fit_loop.run()
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
batch_output = self.batch_loop.run(batch, batch_idx)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
result = self._run_optimization(
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
lightning_module.optimizer_step(
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1652, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 85, in optimizer_step
closure_result = closure()
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 155, in closure
self._backward_fn(step_output.closure_loss)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 327, in backward_fn
self.trainer.accelerator.backward(loss, optimizer, opt_idx)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 311, in backward
self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 91, in backward
model.backward(closure_loss, optimizer, *args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1434, in backward
loss.backward(*args, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/autograd/function.py", line 199, in apply
return user_fn(self, *args)
File "/home/yanq/Codes/stable-diffusion/ldm/modules/diffusionmodules/util.py", line 138, in backward
output_tensors = ctx.run_function(*shallow_copies)
File "/home/yanq/Codes/stable-diffusion/ldm/modules/attention.py", line 212, in _forward
x = self.attn1(self.norm1(x)) + x
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 189, in forward
return F.layer_norm(
File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py", line 2347, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Half but found Float
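For context: per the traceback, the error is raised while CheckpointFunction.backward re-runs the checkpointed block (util.py line 138 → attention.py _forward → layer_norm), i.e. outside the autocast region, where a Half activation meets LayerNorm's Float weights. A minimal sketch of that mismatch (my own illustration, not code from the repo; newer PyTorch releases may accept the mixed dtypes without raising):

```python
import torch

ln = torch.nn.LayerNorm(64).cuda()  # parameters stay float32
x = torch.randn(2, 64, device="cuda", dtype=torch.float16)  # Half activation, as produced under autocast

with torch.cuda.amp.autocast():
    ln(x)  # fine: autocast runs layer_norm in float32

ln(x)  # outside autocast (as in the checkpoint recompute):
       # "RuntimeError: expected scalar type Half but found Float" on 1.10-era PyTorch
```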
I'm facing exactly the same problem. I see that you fixed the error by disabling gradient checkpointing (https://github.com/rinongal/textual_inversion/blob/0a950b482d2e8f215122805d4c5901bdb4a6947f/ldm/modules/diffusionmodules/util.py#L112).
Cool! Hope we find a way to make it compatible, though.
There are two methods to use AMP here; I prefer method 2 in my case.
- Add custom_fwd and custom_bwd to use gradient checkpointing together with AMP, which gives a 1.2x speedup (A5000) at the same memory cost. ldm/modules/diffusionmodules/util.py, line 119:
```python
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)  # add this
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.cuda.amp.custom_bwd  # add this
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
```
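Why this works: with cast_inputs=torch.float32, custom_fwd casts the floating-point CUDA inputs to float32 and runs the checkpointed block with autocast disabled, so the recompute inside backward also sees float32 activations and they match LayerNorm's float32 weights. The trade-off is that everything inside checkpointed blocks runs in full precision, which is presumably why this variant is slower (1.2x vs. 1.4x) than disabling checkpointing entirely.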
- Fully disable gradient checkpointing and use AMP only (as in rinongal/textual_inversion), which gives a 1.4x speedup at 1.5x the memory cost. ldm/modules/diffusionmodules/util.py, line 112:
```python
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if False:  # change flag to False
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
```
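To check which trade-off is worth it on your own GPU, a rough way to reproduce the speed/memory numbers above is to time a few optimizer steps and read torch.cuda.max_memory_allocated() under each configuration. A minimal, self-contained sketch (the toy model and step count are placeholders, not the LDM training loop):

```python
import time
import torch

def benchmark(model, batch, steps=20):
    """Rough throughput / peak-memory probe for one configuration."""
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():  # the precision=16 path
            loss = model(batch).float().pow(2).mean()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
    torch.cuda.synchronize()
    secs = time.time() - start
    peak_gb = torch.cuda.max_memory_allocated() / 2**30
    print(f"{steps / secs:.2f} it/s, peak {peak_gb:.2f} GiB")

# Stand-in block; swap in the UNet built with use_checkpoint True/False to compare.
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)]).cuda()
batch = torch.randn(64, 1024, device="cuda")
benchmark(model, batch)
```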