Weird error when using activation checkpointing with FSDPStrategy
I'm training TinyLlama with 8 A40s. Everything goes very smoothly until I try to increase the micro batch size for a better computation-to-communication ratio.
Following the official lit-gpt tutorial, I pass activation_checkpointing_policy={Block} into FSDPStrategy. The modified setup is attached below.
def setup(
    devices: int = 8,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    precision: Optional[str] = None,
    tpu: bool = False,
    resume: Union[bool, Path] = False,
) -> None:
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)
    if devices > 1:
        if tpu:
            ...
        else:
            strategy = FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy={Block},
                state_dict_type="full",
                limit_all_gathers=True,
                cpu_offload=False,
                sharding_strategy="FULL_SHARD",
            )
    else:
        strategy = "auto"
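For context, after this the strategy is just handed to Fabric before main() runs; I did not touch that part. A simplified sketch of the continuation (assuming `import lightning as L`, and omitting the logger/monitor arguments, so this mirrors rather than copies the upstream script):

    # simplified continuation of setup(); loggers omitted for brevity
    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision)
    fabric.launch()
    main(fabric, train_data_dir, val_data_dir, resume)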
But I got a strange error related to activation checkpointing. Could someone shed some light on this? Anything informative would be a big help.
Traceback (most recent call last):
  File "pretrain/tinyllama.py", line 424, in <module>
    CLI(setup)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "pretrain/tinyllama.py", line 108, in setup
    main(fabric, train_data_dir, val_data_dir, resume)
  File "pretrain/tinyllama.py", line 160, in main
    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
  File "pretrain/tinyllama.py", line 244, in train
    fabric.backward(loss / gradient_accumulation_steps)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/fabric.py", line 422, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/strategies/strategy.py", line 192, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fsdp.py", line 126, in backward
    super().backward(tensor, model, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/precision.py", line 107, in backward
    tensor.backward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 812, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 27
Number of tensors saved during recomputation: 8
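If I read the error right, torch.utils.checkpoint is complaining that the forward pass saved a different set of tensors when it was replayed during backward than it did originally. Here is a deliberately broken, standalone sketch of that situation, just to illustrate the error class; it has nothing to do with lit-gpt, and fn and its branching are made up for illustration:

import torch
from torch.utils.checkpoint import checkpoint

state = {"calls": 0}

def fn(x):
    # Deliberately behaves differently on the recomputation pass than on the
    # original forward, so the two passes save different numbers of tensors.
    state["calls"] += 1
    if state["calls"] == 1:
        return torch.matmul(x, x)  # matmul saves its operands for backward
    return x * 2.0                 # multiplying by a Python scalar saves no tensors

x = torch.randn(4, 4, requires_grad=True)
y = checkpoint(fn, x, use_reentrant=False)
y.sum().backward()  # CheckpointError: different number of tensors saved

So something about the forward of Block apparently produces a different autograd graph when it is replayed, which is the part I don't understand.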
Hi! The setup you shared in your first snippet is very different from the setup in https://github.com/Lightning-AI/lit-gpt/blob/main/pretrain/tinyllama.py#L66. Can you share all of the changes you made to the repo? You can do:
git diff > changes.diff
And then post the changes.diff file here.
cc @awaelchli in case you are familiar
Same problem here.