🐛 Describe the bug
When I train Stable Diffusion with precision: 16, the following error is raised:
Traceback (most recent call last):
File "/work/repo/main.py", line 919, in <module>
raise err
File "/work/repo/main.py", line 901, in <module>
trainer.fit(model, data)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
self._run(model)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
self._dispatch()
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
self.accelerator.start_training(self)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
self.training_type_plugin.start_training(trainer)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
return self._run_train()
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
self.fit_loop.run()
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
epoch_output = self.epoch_loop.run(train_dataloader)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 101, in run
super().run(batch, batch_idx, dataloader_idx)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 148, in advance
result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 202, in _run_optimization
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 396, in _optimizer_step
model_ref.optimizer_step(
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1618, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 209, in step
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 129, in __optimizer_step
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 292, in optimizer_step
make_optimizer_step = self.precision_plugin.pre_optimizer_step(
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 59, in pre_optimizer_step
result = lambda_closure()
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 236, in _training_step_and_backward_closure
result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 549, in training_step_and_backward
self.backward(result, optimizer, opt_idx)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 590, in backward
result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 276, in backward
self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 78, in backward
model.backward(closure_loss, optimizer, *args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1481, in backward
loss.backward(*args, **kwargs)
File "/root/conda/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/root/conda/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/root/conda/lib/python3.9/site-packages/torch/autograd/function.py", line 253, in apply
return user_fn(self, *args)
File "/work/repo/ldm/modules/diffusionmodules/util.py", line 138, in backward
output_tensors = ctx.run_function(*shallow_copies)
File "/work/repo/ldm/modules/attention.py", line 215, in _forward
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
File "/root/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/root/conda/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 189, in forward
return F.layer_norm(
File "/root/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 2503, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Half but found Float
Environment
pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org
When I change the code in /root/conda/lib/python3.9/site-packages/torch/nn/modules/normalization.py from
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps)
to
return F.layer_norm(
input, self.normalized_shape, self.weight.type(input.dtype), self.bias.type(input.dtype), self.eps)
the same error is then raised in F.linear instead.