
FSDPStrategy error when automatic_optimization=False

Open · carlosgjs opened this issue on Mar 13, 2024

Bug description

A basic example with MNIST breaks when using the FSDP strategy if using automatic_optimization=False + explicit calls to manual_backward(loss).

The error seems to stem from the following sequence during training:

  1. Strategy.training_step() redirects the forward() call on the model to training_step. See: https://github.com/Lightning-AI/pytorch-lightning/blob/b3275e05d1e6ba0347c89c2f235990614da2ec5d/src/lightning/pytorch/strategies/strategy.py#L390
  2. This calls into FullyShardedDataParallel.forward(), which:
  3. Calls _pre_forward, which moves the handle state from IDLE to HandleTrainingState.FORWARD, then:
     a. Issues the wrapped forward() call, which is redirected to MNISTModel.training_step.
     b. Within MNISTModel.training_step(), we call manual_backward(), which eventually triggers FSDP's _post_backward_hook. That hook asserts that the handle is in the BACKWARD_PRE or BACKWARD_POST state, raising the error: https://github.com/pytorch/pytorch/blob/0c1ac4484d174d55e3cb06fd103b869cf3b34240/torch/distributed/fsdp/_runtime_utils.py#L713

From what I can tell, FSDP expects to still be in the forward pass, but on this path we call manual_backward(), which puts the handle into an invalid state.
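
For reference, the same invalid state can be reproduced with raw PyTorch FSDP, no Lightning involved. Below is a minimal sketch (untested here; it assumes a single GPU and that a local single-rank process group can be initialized) that calls backward() while still inside the FSDP-wrapped forward():

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class BackwardInsideForward(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 1)

    def forward(self, x):
        loss = self.l1(x).sum()
        # FSDP registers its post-backward hooks during its pre-forward logic,
        # but the handle is still in HandleTrainingState.FORWARD here, so this
        # backward() should trip the same assertion that manual_backward() does.
        loss.backward()
        return loss

# Single-rank process group, just for this demo.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

model = FSDP(BackwardInsideForward().cuda())
# Expected to raise: Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got HandleTrainingState.FORWARD
model(torch.randn(4, 8, device="cuda"))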

What version are you seeing the problem on?

v2.2

How to reproduce the bug

PATH_DATASETS = "~/datasets"
BATCH_SIZE = 256
DATA_SUBSET = BATCH_SIZE

class MNISTModel(LightningModule):
    def __init__(self, auto_opt=True):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.automatic_optimization = auto_opt 

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        if self.automatic_optimization:
            loss = self._train_step_auto(batch)
        else:
            loss = self._train_step_manual(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def _train_step_auto(self, batch):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss
    
    def _train_step_manual(self, batch):
        opt = self.optimizers()
        self.zero_grad()
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # This call fails with FSDP
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
    

train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(Subset(train_ds, range(DATA_SUBSET)), batch_size=BATCH_SIZE)

def test(auto_opt: bool):
    strategy = FSDPStrategy() 
    trainer = Trainer(max_epochs=1, strategy=strategy)
    model = MNISTModel(auto_opt=auto_opt)
    trainer.fit(model, train_loader)

test(True) # automatic optimization works
test(False) # manual optimization fails

Error messages and logs

  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 701, in _post_backward_hook
    _p_assert(
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert
    traceback.print_stack()
Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got HandleTrainingState.FORWARD
Traceback (most recent call last):
  File "/src/repos/t/experiments/fsdp/repro.py", line 61, in <module>
    test(False) # fails
  File "/src/repos/t/experiments/fsdp/repro.py", line 58, in test
    trainer.fit(model, train_loader)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1032, in _run_stage
    self.fit_loop.run()
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 138, in run
    self.advance(data_fetcher)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 244, in advance
    batch_output = self.manual_optimization.run(kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 94, in run
    self.advance(kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 114, in advance
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in training_step
    return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 642, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 635, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/src/repos/t/experiments/fsdp/repro.py", line 27, in training_step
    loss = self._train_step_manual(batch)
  File "/src/repos/t/experiments/fsdp/repro.py", line 43, in _train_step_manual
    self.manual_backward(loss)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1071, in manual_backward
    self.trainer.strategy.backward(loss, None, *args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 213, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
    model.backward(tensor, *args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1090, in backward
    loss.backward(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 701, in _post_backward_hook
    _p_assert(
  File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/utils.py", line 146, in _p_assert
    raise AssertionError(s)
AssertionError: Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got HandleTrainingState.FORWARD

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA RTX A6000
    • available: True
    • version: 12.1
  • Lightning:
    • lightning: 2.2.0.post0
    • lightning-utilities: 0.10.1
    • pytorch-lightning: 2.2.0.post0
    • torch: 2.2.1
    • torchmetrics: 1.3.1
    • torchvision: 0.17.1
  • Packages:
    • absl-py: 2.1.0
    • aiohttp: 3.9.3
    • aiosignal: 1.3.1
    • antlr4-python3-runtime: 4.9.3
    • appdirs: 1.4.4
    • async-timeout: 4.0.3
    • attrs: 23.2.0
    • av: 11.0.0
    • azure-core: 1.30.1
    • azure-identity: 1.15.0
    • azure-storage-blob: 12.19.1
    • bitsandbytes: 0.41.0
    • certifi: 2024.2.2
    • cffi: 1.16.0
    • cfgv: 3.4.0
    • charset-normalizer: 3.3.2
    • click: 8.1.7
    • contourpy: 1.2.0
    • cryptography: 42.0.5
    • cycler: 0.12.1
    • distlib: 0.3.8
    • docker-pycreds: 0.4.0
    • docstring-parser: 0.15
    • filelock: 3.13.1
    • fonttools: 4.49.0
    • frozenlist: 1.4.1
    • fsspec: 2024.2.0
    • gitdb: 4.0.11
    • gitpython: 3.1.42
    • grpcio: 1.62.0
    • hydra-core: 1.3.2
    • identify: 2.5.35
    • idna: 3.6
    • importlib-resources: 6.1.2
    • isodate: 0.6.1
    • jinja2: 3.1.3
    • jsonargparse: 4.27.5
    • kiwisolver: 1.4.5
    • lightning: 2.2.0.post0
    • lightning-utilities: 0.10.1
    • markdown: 3.5.2
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.5
    • matplotlib: 3.8.3
    • mdurl: 0.1.2
    • mpmath: 1.3.0
    • msal: 1.27.0
    • msal-extensions: 1.1.0
    • multidict: 6.0.5
    • networkx: 3.2.1
    • nodeenv: 1.8.0
    • numpy: 1.26.4
    • nvidia-cublas-cu12: 12.1.3.1
    • nvidia-cuda-cupti-cu12: 12.1.105
    • nvidia-cuda-nvrtc-cu12: 12.1.105
    • nvidia-cuda-runtime-cu12: 12.1.105
    • nvidia-cudnn-cu12: 8.9.2.26
    • nvidia-cufft-cu12: 11.0.2.54
    • nvidia-curand-cu12: 10.3.2.106
    • nvidia-cusolver-cu12: 11.4.5.107
    • nvidia-cusparse-cu12: 12.1.0.106
    • nvidia-nccl-cu12: 2.19.3
    • nvidia-nvjitlink-cu12: 12.3.101
    • nvidia-nvtx-cu12: 12.1.105
    • omegaconf: 2.3.0
    • packaging: 23.2
    • pillow: 10.2.0
    • pip: 23.3.1
    • platformdirs: 4.2.0
    • portalocker: 2.8.2
    • pre-commit: 3.6.2
    • protobuf: 4.25.3
    • psutil: 5.9.8
    • pycparser: 2.21
    • pygments: 2.17.2
    • pyjwt: 2.8.0
    • pyparsing: 3.1.1
    • python-dateutil: 2.9.0.post0
    • pytorch-lightning: 2.2.0.post0
    • pyyaml: 6.0.1
    • requests: 2.31.0
    • rich: 13.7.1
    • sentry-sdk: 1.40.6
    • setproctitle: 1.3.3
    • setuptools: 68.2.2
    • six: 1.16.0
    • smmap: 5.0.1
    • sympy: 1.12
    • tensorboard: 2.16.2
    • tensorboard-data-server: 0.7.2
    • tensorboardx: 2.6.2.2
    • torch: 2.2.1
    • torchmetrics: 1.3.1
    • torchvision: 0.17.1
    • tqdm: 4.66.2
    • triton: 2.2.0
    • typeshed-client: 2.5.1
    • typing-extensions: 4.10.0
    • urllib3: 2.2.1
    • virtualenv: 20.25.1
    • wandb: 0.16.4
    • werkzeug: 3.0.1
    • wheel: 0.41.2
    • yarl: 1.9.4
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.13
    • release: 6.1.0-1034-oem
    • version: #34-Ubuntu SMP PREEMPT_DYNAMIC Mon Feb 5 18:29:21 UTC 2024

More info

No response

carlosgjs · Mar 13, 2024

I'm having the same issue when running FSDP + manual backward in a different setup. I was also able to reproduce the bug using the code provided above.

Remark: OP (@carlosgjs) forgot to include the imports in the script. I'm adding them below for convenience.

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import FSDPStrategy
import torch
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

gleize · Mar 24, 2024

@carlosgjs I found a workaround. Essentially, we need to return from the FSDP forward function so that the handle state can transition from FORWARD to BACKWARD_PRE. So I moved the manual_backward call to the on_train_batch_end hook, which runs after training_step (and thus FSDP's forward) has returned.

Here is the workaround.

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import FSDPStrategy
import torch
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

from loguru import logger

PATH_DATASETS = "~/datasets"
BATCH_SIZE = 256
DATA_SUBSET = 100 * BATCH_SIZE


class MNISTModel(LightningModule):
    def __init__(self, auto_opt=True):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.automatic_optimization = auto_opt

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        if self.automatic_optimization:
            loss = self._train_step_auto(batch)
        else:
            loss = self._train_step_manual(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def _train_step_auto(self, batch):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def _train_step_manual(self, batch):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        if not self.automatic_optimization:
            self.loss = loss
        return loss

    # Runs after training_step has returned, i.e. after FSDP has exited its forward pass.
    def on_train_batch_end(self, outputs, batch, batch_idx):
        if not self.automatic_optimization:
            self.zero_grad()
            # self.manual_backward(outputs["loss"]) # cannot use this because the backward graph has been flushed
            self.manual_backward(self.loss)
            self.loss = None
            self.optimizers().step()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


train_ds = MNIST(
    PATH_DATASETS,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(Subset(train_ds, range(DATA_SUBSET)), batch_size=BATCH_SIZE)


def test(auto_opt: bool):
    try:
        strategy = FSDPStrategy()
        trainer = Trainer(max_epochs=1, strategy=strategy)
        model = MNISTModel(auto_opt=auto_opt)
        trainer.fit(model, train_loader)
    except Exception as e:
        logger.opt(exception=True).debug("exception")
        raise e from e


test(True)  # automatic optimization works
test(False)  # manual optimization now works with the workaround
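
One caveat: stashing the loss on self keeps the whole autograd graph alive until on_train_batch_end fires, so peak memory will be higher than with an in-step backward.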

gleize · Mar 24, 2024

I'm facing the following error when I apply your solution. @gleize, any idea how I can resolve it?

ValueError: expected to be in states [<TrainingState.FORWARD_BACKWARD: 2>] but current state is TrainingState.IDLE

mojivalipour · Sep 24, 2024

I'm hitting the same problem too. Any update on this?

RobertLuo1 · Sep 28, 2024

I'm hitting the same problem too, and I want to call manual_backward within a single training step. Is there any update on this?

Linmj-Judy · Dec 23, 2025