FSDPStrategy error when automatic_optimization=False
Bug description
A basic MNIST example breaks when using the FSDP strategy with automatic_optimization=False and explicit calls to manual_backward(loss).
The error seems to stem from the following sequence during training:
1. `Strategy.training_step()` redirects the `forward()` call on the model to `training_step()`. See: https://github.com/Lightning-AI/pytorch-lightning/blob/b3275e05d1e6ba0347c89c2f235990614da2ec5d/src/lightning/pytorch/strategies/strategy.py#L390
2. This calls into `FullyShardedDataParallel.forward()`, which:
   a. Calls `_pre_forward`, which sets the handle state to `FORWARD`: `handle._training_state = HandleTrainingState.FORWARD` (from `IDLE`).
   b. Issues the wrapped `forward()` call, which is redirected to `MNISTModel.training_step()`.
   c. Within `MNISTModel.training_step()`, we call `manual_backward()`, which eventually triggers FSDP's `_post_backward_hook`. That hook asserts that the handle is in the `BACKWARD_PRE` or `BACKWARD_POST` state, producing the error: https://github.com/pytorch/pytorch/blob/0c1ac4484d174d55e3cb06fd103b869cf3b34240/torch/distributed/fsdp/_runtime_utils.py#L713

From what I can tell, FSDP expects to still be in the forward pass, but on this path we call `manual_backward()` while inside `FullyShardedDataParallel.forward()`, triggering an invalid state.
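For illustration, here is a minimal, hypothetical sketch of the redirection pattern described above (`redirect_training_step` and `wrapped_forward` are made-up names, not Lightning's actual internals). The point is that FSDP's pre- and post-forward hooks end up wrapping the entire `training_step`, so any `manual_backward()` issued inside it runs while FSDP still thinks it is in the forward pass.
def redirect_training_step(fsdp_wrapper, lightning_module, *args, **kwargs):
    # Hypothetical sketch only; names and structure are simplified for illustration.
    original_forward = lightning_module.forward

    def wrapped_forward(*f_args, **f_kwargs):
        # Restore the real forward so self(x) inside training_step behaves normally.
        lightning_module.forward = original_forward
        return lightning_module.training_step(*f_args, **f_kwargs)

    # Temporarily point the inner module's forward at training_step ...
    lightning_module.forward = wrapped_forward
    try:
        # ... and invoke the FSDP wrapper. FSDP's _pre_forward sets the handle state
        # to FORWARD here. Because manual_backward() runs inside this call, FSDP's
        # _post_backward_hook fires while the state is still FORWARD instead of
        # BACKWARD_PRE/BACKWARD_POST, which trips the assertion shown in the logs.
        return fsdp_wrapper(*args, **kwargs)
    finally:
        lightning_module.forward = original_forward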
What version are you seeing the problem on?
v2.2
How to reproduce the bug
PATH_DATASETS = "~/datasets"
BATCH_SIZE = 256
DATA_SUBSET = BATCH_SIZE
class MNISTModel(LightningModule):
    def __init__(self, auto_opt=True):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
self.automatic_optimization = auto_opt
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
if self.automatic_optimization:
loss = self._train_step_auto(batch)
else:
loss = self._train_step_manual(batch)
self.log(f"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
def _train_step_auto(self, batch):
x, y = batch
loss = F.cross_entropy(self(x), y)
return loss
def _train_step_manual(self, batch):
opt = self.optimizers()
self.zero_grad()
x, y = batch
loss = F.cross_entropy(self(x), y)
# This call fails with FSDP
self.manual_backward(loss)
opt.step()
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(Subset(train_ds, range(DATA_SUBSET)), batch_size=BATCH_SIZE)
def test(auto_opt: bool):
strategy = FSDPStrategy()
trainer = Trainer(max_epochs=1, strategy=strategy)
model = MNISTModel(auto_opt=auto_opt)
trainer.fit(model, train_loader)
test(True) # automatic optimization works
test(False) # manual optimization fails
Error messages and logs
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 701, in _post_backward_hook
_p_assert(
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert
traceback.print_stack()
Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got HandleTrainingState.FORWARD
Traceback (most recent call last):
File "/src/repos/t/experiments/fsdp/repro.py", line 61, in <module>
test(False) # fails
File "/src/repos/t/experiments/fsdp/repro.py", line 58, in test
trainer.fit(model, train_loader)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
results = self._run_stage()
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1032, in _run_stage
self.fit_loop.run()
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
self.advance()
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 138, in run
self.advance(data_fetcher)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 244, in advance
batch_output = self.manual_optimization.run(kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 94, in run
self.advance(kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 114, in advance
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in training_step
return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 642, in __call__
wrapper_output = wrapper_module(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 635, in wrapped_forward
out = method(*_args, **_kwargs)
File "/src/repos/t/experiments/fsdp/repro.py", line 27, in training_step
loss = self._train_step_manual(batch)
File "/src/repos/t/experiments/fsdp/repro.py", line 43, in _train_step_manual
self.manual_backward(loss)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1071, in manual_backward
self.trainer.strategy.backward(loss, None, *args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 213, in backward
self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
model.backward(tensor, *args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1090, in backward
loss.backward(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 701, in _post_backward_hook
_p_assert(
File "/src/miniconda3/envs/diffusion/lib/python3.10/site-packages/torch/distributed/utils.py", line 146, in _p_assert
raise AssertionError(s)
AssertionError: Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got HandleTrainingState.FORWARD
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA RTX A6000
- available: True
- version: 12.1
- Lightning:
- lightning: 2.2.0.post0
- lightning-utilities: 0.10.1
- pytorch-lightning: 2.2.0.post0
- torch: 2.2.1
- torchmetrics: 1.3.1
- torchvision: 0.17.1
- Packages:
- absl-py: 2.1.0
- aiohttp: 3.9.3
- aiosignal: 1.3.1
- antlr4-python3-runtime: 4.9.3
- appdirs: 1.4.4
- async-timeout: 4.0.3
- attrs: 23.2.0
- av: 11.0.0
- azure-core: 1.30.1
- azure-identity: 1.15.0
- azure-storage-blob: 12.19.1
- bitsandbytes: 0.41.0
- certifi: 2024.2.2
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- contourpy: 1.2.0
- cryptography: 42.0.5
- cycler: 0.12.1
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- docstring-parser: 0.15
- filelock: 3.13.1
- fonttools: 4.49.0
- frozenlist: 1.4.1
- fsspec: 2024.2.0
- gitdb: 4.0.11
- gitpython: 3.1.42
- grpcio: 1.62.0
- hydra-core: 1.3.2
- identify: 2.5.35
- idna: 3.6
- importlib-resources: 6.1.2
- isodate: 0.6.1
- jinja2: 3.1.3
- jsonargparse: 4.27.5
- kiwisolver: 1.4.5
- lightning: 2.2.0.post0
- lightning-utilities: 0.10.1
- markdown: 3.5.2
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.8.3
- mdurl: 0.1.2
- mpmath: 1.3.0
- msal: 1.27.0
- msal-extensions: 1.1.0
- multidict: 6.0.5
- networkx: 3.2.1
- nodeenv: 1.8.0
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu12: 2.19.3
- nvidia-nvjitlink-cu12: 12.3.101
- nvidia-nvtx-cu12: 12.1.105
- omegaconf: 2.3.0
- packaging: 23.2
- pillow: 10.2.0
- pip: 23.3.1
- platformdirs: 4.2.0
- portalocker: 2.8.2
- pre-commit: 3.6.2
- protobuf: 4.25.3
- psutil: 5.9.8
- pycparser: 2.21
- pygments: 2.17.2
- pyjwt: 2.8.0
- pyparsing: 3.1.1
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.2.0.post0
- pyyaml: 6.0.1
- requests: 2.31.0
- rich: 13.7.1
- sentry-sdk: 1.40.6
- setproctitle: 1.3.3
- setuptools: 68.2.2
- six: 1.16.0
- smmap: 5.0.1
- sympy: 1.12
- tensorboard: 2.16.2
- tensorboard-data-server: 0.7.2
- tensorboardx: 2.6.2.2
- torch: 2.2.1
- torchmetrics: 1.3.1
- torchvision: 0.17.1
- tqdm: 4.66.2
- triton: 2.2.0
- typeshed-client: 2.5.1
- typing-extensions: 4.10.0
- urllib3: 2.2.1
- virtualenv: 20.25.1
- wandb: 0.16.4
- werkzeug: 3.0.1
- wheel: 0.41.2
- yarl: 1.9.4
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.13
- release: 6.1.0-1034-oem
- version: #34-Ubuntu SMP PREEMPT_DYNAMIC Mon Feb 5 18:29:21 UTC 2024
More info
No response
I'm having the same issue when running FSDP + manual backward in a different setup. I was also able to reproduce the bug using the code provided above.
Remark: OP (@carlosgjs) forgot to include the imports in the script. I'm adding them below for convenience.
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import FSDPStrategy
import torch
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
@carlosgjs I found a workaround.
Essentially, we need to exit FSDP's forward function so that it can transition its state from FORWARD to BACKWARD_PRE.
So I moved the manual_backward call to the on_train_batch_end hook.
Here is the fix.
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import FSDPStrategy
import torch
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from loguru import logger
PATH_DATASETS = "~/datasets"
BATCH_SIZE = 256
DATA_SUBSET = 100 * BATCH_SIZE
class MNISTModel(LightningModule):
def __init__(self, auto_opt=True):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
self.automatic_optimization = auto_opt
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
if self.automatic_optimization:
loss = self._train_step_auto(batch)
else:
loss = self._train_step_manual(batch)
self.log(f"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
def _train_step_auto(self, batch):
x, y = batch
loss = F.cross_entropy(self(x), y)
return loss
def _train_step_manual(self, batch):
x, y = batch
loss = F.cross_entropy(self(x), y)
if not self.automatic_optimization:
self.loss = loss
return loss
    # Backward and optimizer step are deferred to here, after FSDP has exited its forward pass
def on_train_batch_end(self, outputs, batch, batch_idx):
if not self.automatic_optimization:
self.zero_grad()
# self.manual_backward(outputs["loss"]) # cannot use this because the backward graph has been flushed
self.manual_backward(self.loss)
self.loss = None
self.optimizers().step()
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
train_ds = MNIST(
PATH_DATASETS,
train=True,
download=True,
transform=transforms.ToTensor(),
)
train_loader = DataLoader(Subset(train_ds, range(DATA_SUBSET)), batch_size=BATCH_SIZE)
def test(auto_opt: bool):
try:
strategy = FSDPStrategy()
trainer = Trainer(max_epochs=1, strategy=strategy)
model = MNISTModel(auto_opt=auto_opt)
trainer.fit(model, train_loader)
except Exception as e:
logger.opt(exception=True).debug("exception")
raise e from e
test(True) # automatic optimization works
test(False) # manual optimization with the workaround applied
I am facing the following error when using your solution. @gleize Any idea how I can resolve it?
ValueError: expected to be in states [<TrainingState.FORWARD_BACKWARD: 2>] but current state is TrainingState.IDLE
I'm facing the same problem too. Any update on this?
I'm facing the same problem too, and I want to use manual backward within a single training step. Is there any update on this?