pytorch-lightning
Mixed precision with DeepSpeed
Bug description
When using mixed precision (`precision=16`) with DeepSpeed, the model fails with the error: `RuntimeError: expected scalar type Float but found Half`
How to reproduce the bug
class SimpleModel(LightningModule):
    """Pretrained encoder used under ``torch.no_grad()``, followed by a classifier head.

    Only ``self.classifier``'s parameters are handed to the optimizer, so the
    pretrained encoder acts as a fixed feature extractor.

    Args:
        args: model init hyperparameters; ``configure_optimizers`` reads
            ``args['lr']``.
    """

    def __init__(self, args):
        # LightningModule is an nn.Module: its __init__ must run before any
        # submodule is assigned, otherwise nn.Module.__setattr__ raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.args = args
        self.pretrain_model = Bert()
        self.classifier = SimpleMLP()

    def _shared_step(self, batch):
        """Run the encoder and the classifier on one ``(tokens, labels)`` batch.

        Returns:
            The ``(preds, loss)`` pair produced by ``self.classifier``.
        """
        tokens, labels = batch
        # No gradients flow into the pretrained encoder; it is not optimized.
        # NOTE(review): the reported Float/Half mismatch is raised inside the
        # encoder's LayerNorm under DeepSpeed fp16. A dtype promotion inside
        # the third-party encoder is a plausible cause -- confirm.
        with torch.no_grad():
            embeddings = self.pretrain_model(tokens)
        preds, loss = self.classifier(embeddings, labels)
        return preds, loss

    def forward(self, x):
        # In Lightning, forward defines the prediction/inference action.
        # ``x`` is a ``(tokens, labels)`` batch; only the predictions are returned.
        preds, _ = self._shared_step(x)
        return preds

    def training_step(self, batch, batch_idx):
        # Lightning logs to TensorBoard by default.
        _, loss = self._shared_step(batch)
        self.log("training_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        _, loss = self._shared_step(batch)
        self.log("test_loss", loss, on_epoch=True, on_step=True, sync_dist=True)

    def configure_optimizers(self):
        # Only the classifier head is trained.
        optimizer = torch.optim.Adam(
            self.classifier.parameters(),
            lr=self.args['lr'],
            weight_decay=0.01,
            eps=1e-6,
        )
        return optimizer
# Build the model and train it with DeepSpeed ZeRO stage 3 in 16-bit precision.
# ``args``, ``pl`` and ``dataset`` are defined elsewhere in the original script.
model = SimpleModel(args=args)
trainer = pl.Trainer(
    devices=4,
    strategy="deepspeed_stage_3",
    precision=16,
    max_epochs=20,
    accelerator='gpu',
)
trainer.fit(model, datamodule=dataset)
Error messages and logs
# Error messages and logs here please
Traceback (most recent call last):
File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/", line 228, in <module>
File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/", line 196, in main
trainer.fit(model, datamodule=dataset)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 770, in fit
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 723, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 811, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1236, in _run
results = self._run_stage()
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1323, in _run_stage
return self._run_train()
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1345, in _run_train
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1413, in _run_sanity_check
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/", line 204, in run
self.advance(*args, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/", line 155, in advance
dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/", line 204, in run
self.advance(*args, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/", line 128, in advance
output = self._evaluation_step(**kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/", line 226, in _evaluation_step
output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1765, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/strategies/", line 906, in validation_step
return self.model(*args, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/deepspeed/utils/", line 11, in wrapped_fn
return func(*args, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/deepspeed/runtime/", line 1599, in forward
loss = self.module(*inputs, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/", line 1128, in _call_impl
result = forward_call(*input, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/strategies/", line 80, in forward
return super().forward(*inputs, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/overrides/", line 93, in forward
return self.module.validation_step(*inputs, **kwargs)
File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/", line 129, in validation_step
protein_dict = self.esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/", line 1128, in _call_impl
result = forward_call(*input, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/esm/", line 140, in forward
x = self.emb_layer_norm_before(x)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/", line 1128, in _call_impl
result = forward_call(*input, **kwargs)
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/", line 189, in forward
return F.layer_norm(
File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/", line 2486, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Float but found Half
Current Environment
- GPU:
- GeForce RTX 3090
- GeForce RTX 3090
- GeForce RTX 3090
- GeForce RTX 3090
- GeForce RTX 3090
- GeForce RTX 3090
- GeForce RTX 3090
- GeForce RTX 3090
- available: True
- version: 11.3
* Lightning:
- pytorch-lightning: 1.6.5
- torch: 1.11.0
- torchaudio: 0.11.0
- torchinfo: 1.7.0
- torchmetrics: 0.10.0
- torchvision: 0.12.0
* Packages:
- absl-py: 1.0.0
- aiohttp: 3.8.1
- aiosignal: 1.2.0
- asttokens: 2.0.5
- async-timeout: 4.0.2
- attrs: 21.4.0
- backcall: 0.2.0
- biopython: 1.79
- brotlipy: 0.7.0
- cached-property: 1.5.2
- cachetools: 5.0.0
- certifi: 2022.6.15
- cffi: 1.14.4
- charset-normalizer: 2.1.0
- click: 8.1.3
- cryptography: 37.0.2
- cycler: 0.11.0
- decorator: 5.1.1
- deepspeed: 0.6.6
- deprecated: 1.2.13
- distlib: 0.3.4
- docker-pycreds: 0.4.0
- einops: 0.4.0
- executing: 0.8.3
- fair-esm: 0.4.2
- fairscale: 0.4.6
- filelock: 3.7.0
- fonttools: 4.29.1
- frozenlist: 1.3.0
- fsspec: 2022.2.0
- future: 0.18.2
- gitdb: 4.0.9
- gitpython: 3.1.27
- google-auth: 2.6.0
- google-auth-oauthlib: 0.4.6
- grpcio: 1.44.0
- h5py: 3.6.0
- hjson: 3.0.2
- huggingface-hub: 0.6.0
- idna: 3.3
- importlib-metadata: 4.11.2
- infinibatch: 0.1.0
- ipython: 8.1.0
- jedi: 0.18.1
- joblib: 1.1.0
- kiwisolver: 1.3.2
- lmdb: 1.3.0
- lxml: 4.8.0
- markdown: 3.3.6
- matplotlib: 3.5.1
- matplotlib-inline: 0.1.3
- mkl-fft: 1.3.1
- mkl-random: 1.2.2
- mkl-service: 2.4.0
- multidict: 6.0.2
- ninja:
- numpy: 1.22.3
- oauthlib: 3.2.0
- packaging: 21.3
- pandas: 1.4.2
- parso: 0.8.3
- pathtools: 0.1.2
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.1.1
- pip: 22.1.2
- platformdirs: 2.5.2
- plip: 2.2.2
- promise: 2.3
- prompt-toolkit: 3.0.28
- protobuf: 3.19.4
- psutil: 5.9.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 8.0.0
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycparser: 2.21
- pydantic: 1.9.1
- pydeprecate: 0.3.1
- pygments: 2.11.2
- pyopenssl: 22.0.0
- pyparsing: 3.0.7
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- pytorch-lightning: 1.6.5
- pytz: 2022.1
- pyyaml: 6.0
- redis: 4.3.1
- regex: 2022.4.24
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- rsa: 4.8
- scikit-learn: 1.1.1
- scipy: 1.8.0
- sentencepiece: 0.1.97
- sentry-sdk: 1.5.12
- setproctitle: 1.2.3
- setuptools: 62.6.0
- shortuuid: 1.0.9
- six: 1.16.0
- smmap: 5.0.0
- stack-data: 0.2.0
- tensorboard: 2.8.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- threadpoolctl: 3.1.0
- tokenizers: 0.12.1
- torch: 1.11.0
- torchaudio: 0.11.0
- torchinfo: 1.7.0
- torchmetrics: 0.10.0
- torchvision: 0.12.0
- tqdm: 4.63.0
- traitlets: 5.1.1
- transformers: 4.21.2
- triton: 1.0.0
- typing-extensions: 4.3.0
- urllib3: 1.26.9
- virtualenv: 20.14.1
- wandb: 0.12.16
- wcwidth: 0.2.5
- werkzeug: 2.0.3
- wheel: 0.37.1
- wrapt: 1.14.1
- yarl: 1.7.2
- zipp: 3.7.0
* System:
- OS: Linux
- architecture:
- 64bit
- processor: x86_64
- python: 3.8.0
- version: #1 SMP Thu Nov 8 23:39:32 UTC 2018
More info
No response
cc @awaelchli
Hi @wangleiofficial, I ran into the same problem. Did you manage to fix it?
@Line290 Not yet. I guess some of the parameters (those of the pretrained model) are not handled correctly.
I've got the same problem, any fixes yet?
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!
+1, `deepspeed_stage_2` hits the same error.
Similar problem with deepspeed_stage_1.
File "python3.9/site-packages/deepspeed/runtime/zero/", line 2019, in backward
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "python3.9/site-packages/deepspeed/runtime/fp16/", line 63, in backward
scaled_loss.backward(retain_graph=retain_graph)
File "python3.9/site-packages/torch/", line 487, in backward
torch.autograd.backward(
File "python3.9/site-packages/torch/autograd/", line 200, in backward
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Half