
Mixed precision with DeepSpeed

Open wangleiofficial opened this issue 2 years ago • 6 comments

Bug description

When using mixed precision with DeepSpeed, training fails with: RuntimeError: expected scalar type Float but found Half.

How to reproduce the bug

import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule


class SimpleModel(LightningModule):
    """SimpleModel

    Args:
        args: model init hyperparameters
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)
        self.pretrain_model = Bert()       # frozen pretrained encoder
        self.classifier = SimpleMLP()

    def forward(self, batch):
        # In Lightning, forward defines the prediction/inference actions
        tokens, labels = batch
        with torch.no_grad():
            embeddings = self.pretrain_model(tokens)

        preds, loss = self.classifier(embeddings, labels)
        return preds

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; it is independent of forward
        tokens, labels = batch
        with torch.no_grad():
            embeddings = self.pretrain_model(tokens)

        preds, loss = self.classifier(embeddings, labels)
        self.log("training_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        tokens, labels = batch
        with torch.no_grad():
            embeddings = self.pretrain_model(tokens)

        preds, loss = self.classifier(embeddings, labels)
        self.log("val_loss", loss, on_epoch=True, on_step=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        tokens, labels = batch
        with torch.no_grad():
            embeddings = self.pretrain_model(tokens)

        preds, loss = self.classifier(embeddings, labels)
        self.log("test_loss", loss, on_epoch=True, on_step=True, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.classifier.parameters(), lr=self.args['lr'], weight_decay=0.01, eps=1e-6
        )


model = SimpleModel(args=args)
trainer = pl.Trainer(devices=4, strategy="deepspeed_stage_3", precision=16,
                     max_epochs=20, accelerator='gpu')
trainer.fit(model, datamodule=dataset)
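One workaround sketch for this class of failure (not an official fix, and `DtypeSafeWrapper` is a hypothetical name): since the frozen pretrained submodule's parameters can be left in fp32 while the surrounding engine feeds it fp16 activations, cast floating-point inputs to the submodule's own parameter dtype before the no-grad forward pass:

```python
import torch
import torch.nn as nn

class DtypeSafeWrapper(nn.Module):
    """Hypothetical wrapper: cast floating-point inputs to the wrapped
    module's parameter dtype so an fp32 module never receives fp16
    activations (or vice versa)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # dtype of the wrapped module's parameters (assumes it has any)
        self.param_dtype = next(module.parameters()).dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x.to(self.param_dtype))

# Demo: an fp32 LayerNorm fed an fp16 activation through the wrapper
norm = nn.LayerNorm(8)                                 # fp32 parameters
wrapped = DtypeSafeWrapper(norm)
out = wrapped(torch.randn(2, 8, dtype=torch.float16))  # fp16 input is cast up
print(out.dtype)
```

For integer token ids (as in the report) the cast would instead be needed on the embeddings produced inside the pretrained model, so this only sketches the general dtype-alignment idea.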

Error messages and logs


Traceback (most recent call last):
  File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/esm_contactmap_pl.py", line 228, in <module>
    main(params)
  File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/esm_contactmap_pl.py", line 196, in main
    trainer.fit(model, datamodule=dataset)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train
    self._run_sanity_check()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check
    val_loop.run()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 128, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 226, in _evaluation_step
    output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1765, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 906, in validation_step
    return self.model(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1599, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 80, in forward
    return super().forward(*inputs, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    return self.module.validation_step(*inputs, **kwargs)
  File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/esm_contactmap_pl.py", line 129, in validation_step
    protein_dict = self.esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/esm/model.py", line 140, in forward
    x = self.emb_layer_norm_before(x)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/functional.py", line 2486, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Float but found Half
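The failing frame reduces to `F.layer_norm` receiving an fp16 activation while its weight and bias are still fp32. A minimal sketch (a manual-cast workaround, not the strategy's intended behavior) shows that aligning the input dtype with the parameter dtype avoids the mismatch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, dtype=torch.float16)  # fp16 activation, as under precision=16
w, b = torch.ones(4), torch.zeros(4)        # fp32 LayerNorm weight/bias
# Casting the input to the parameter dtype sidesteps the Float/Half clash
out = F.layer_norm(x.to(w.dtype), (4,), weight=w, bias=b)
print(out.dtype)  # torch.float32
```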

Environment

* CUDA:
	- GPU:
		- GeForce RTX 3090
		- GeForce RTX 3090
		- GeForce RTX 3090
		- GeForce RTX 3090
		- GeForce RTX 3090
		- GeForce RTX 3090
		- GeForce RTX 3090
		- GeForce RTX 3090
	- available:         True
	- version:           11.3
* Lightning:
	- pytorch-lightning: 1.6.5
	- torch:             1.11.0
	- torchaudio:        0.11.0
	- torchinfo:         1.7.0
	- torchmetrics:      0.10.0
	- torchvision:       0.12.0
* Packages:
	- absl-py:           1.0.0
	- aiohttp:           3.8.1
	- aiosignal:         1.2.0
	- asttokens:         2.0.5
	- async-timeout:     4.0.2
	- attrs:             21.4.0
	- backcall:          0.2.0
	- biopython:         1.79
	- brotlipy:          0.7.0
	- cached-property:   1.5.2
	- cachetools:        5.0.0
	- certifi:           2022.6.15
	- cffi:              1.14.4
	- charset-normalizer: 2.1.0
	- click:             8.1.3
	- cryptography:      37.0.2
	- cycler:            0.11.0
	- decorator:         5.1.1
	- deepspeed:         0.6.6
	- deprecated:        1.2.13
	- distlib:           0.3.4
	- docker-pycreds:    0.4.0
	- einops:            0.4.0
	- executing:         0.8.3
	- fair-esm:          0.4.2
	- fairscale:         0.4.6
	- filelock:          3.7.0
	- fonttools:         4.29.1
	- frozenlist:        1.3.0
	- fsspec:            2022.2.0
	- future:            0.18.2
	- gitdb:             4.0.9
	- gitpython:         3.1.27
	- google-auth:       2.6.0
	- google-auth-oauthlib: 0.4.6
	- grpcio:            1.44.0
	- h5py:              3.6.0
	- hjson:             3.0.2
	- huggingface-hub:   0.6.0
	- idna:              3.3
	- importlib-metadata: 4.11.2
	- infinibatch:       0.1.0
	- ipython:           8.1.0
	- jedi:              0.18.1
	- joblib:            1.1.0
	- kiwisolver:        1.3.2
	- lmdb:              1.3.0
	- lxml:              4.8.0
	- markdown:          3.3.6
	- matplotlib:        3.5.1
	- matplotlib-inline: 0.1.3
	- mkl-fft:           1.3.1
	- mkl-random:        1.2.2
	- mkl-service:       2.4.0
	- multidict:         6.0.2
	- ninja:             1.10.2.3
	- numpy:             1.22.3
	- oauthlib:          3.2.0
	- packaging:         21.3
	- pandas:            1.4.2
	- parso:             0.8.3
	- pathtools:         0.1.2
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            9.1.1
	- pip:               22.1.2
	- platformdirs:      2.5.2
	- plip:              2.2.2
	- promise:           2.3
	- prompt-toolkit:    3.0.28
	- protobuf:          3.19.4
	- psutil:            5.9.0
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- py-cpuinfo:        8.0.0
	- pyasn1:            0.4.8
	- pyasn1-modules:    0.2.8
	- pycparser:         2.21
	- pydantic:          1.9.1
	- pydeprecate:       0.3.1
	- pygments:          2.11.2
	- pyopenssl:         22.0.0
	- pyparsing:         3.0.7
	- pysocks:           1.7.1
	- python-dateutil:   2.8.2
	- pytorch-lightning: 1.6.5
	- pytz:              2022.1
	- pyyaml:            6.0
	- redis:             4.3.1
	- regex:             2022.4.24
	- requests:          2.28.1
	- requests-oauthlib: 1.3.1
	- rsa:               4.8
	- scikit-learn:      1.1.1
	- scipy:             1.8.0
	- sentencepiece:     0.1.97
	- sentry-sdk:        1.5.12
	- setproctitle:      1.2.3
	- setuptools:        62.6.0
	- shortuuid:         1.0.9
	- six:               1.16.0
	- smmap:             5.0.0
	- stack-data:        0.2.0
	- tensorboard:       2.8.0
	- tensorboard-data-server: 0.6.1
	- tensorboard-plugin-wit: 1.8.1
	- threadpoolctl:     3.1.0
	- tokenizers:        0.12.1
	- torch:             1.11.0
	- torchaudio:        0.11.0
	- torchinfo:         1.7.0
	- torchmetrics:      0.10.0
	- torchvision:       0.12.0
	- tqdm:              4.63.0
	- traitlets:         5.1.1
	- transformers:      4.21.2
	- triton:            1.0.0
	- typing-extensions: 4.3.0
	- urllib3:           1.26.9
	- virtualenv:        20.14.1
	- wandb:             0.12.16
	- wcwidth:           0.2.5
	- werkzeug:          2.0.3
	- wheel:             0.37.1
	- wrapt:             1.14.1
	- yarl:              1.7.2
	- zipp:              3.7.0
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.8.0
	- version:           #1 SMP Thu Nov 8 23:39:32 UTC 2018

More info

No response

cc @awaelchli

wangleiofficial avatar Oct 18 '22 12:10 wangleiofficial

Hi @wangleiofficial, I've run into the same problem. Have you fixed it?

Line290 avatar Oct 28 '22 14:10 Line290

@Line290 Not yet. I suspect that some of the parameters (the pretrained model's) are not being handled correctly.

wangleiofficial avatar Oct 29 '22 08:10 wangleiofficial

I've got the same problem, any fixes yet?

FarzanT avatar Dec 15 '22 17:12 FarzanT

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

stale[bot] avatar Jan 21 '23 12:01 stale[bot]

+1, deepspeed_stage_2 hits the same error.

ewrfcas avatar Sep 10 '23 09:09 ewrfcas

Similar problem with deepspeed_stage_1.

File "python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2019, in backward
  self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "python3.9/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
  scaled_loss.backward(retain_graph=retain_graph)
File "python3.9/site-packages/torch/_tensor.py", line 487, in backward
  torch.autograd.backward(
File "python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Half

YTEP-ZHI avatar Apr 19 '24 15:04 YTEP-ZHI