Error using SFTTrainer with DoRA & FSDP
### System Info

- peft 0.11.2.dev0
- transformers 4.41.2
- trl 0.9.4
- Python 3.12.2
### Who can help?
No response
### Information
- [ ] The official example scripts
- [ ] My own modified scripts
### Tasks
- [ ] An officially supported task in the `examples` folder
- [ ] My own task or dataset (give details below)
### Reproduction
I used the TRL SFTTrainer, with the PEFT config modified to enable `use_dora`.
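For reference, a minimal sketch of the setup (the model name, dataset, and LoRA hyperparameters below are illustrative guesses, not my exact script; the shapes in the traceback happen to match a Qwen2-7B-sized model). The run is launched with `accelerate launch` and an FSDP config:

```python
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model_name = "Qwen/Qwen2-7B"  # placeholder; hidden sizes match the traceback
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
    use_dora=True,  # the only change vs. a working LoRA run
)

training_args = TrainingArguments(
    output_dir="out",
    bf16=True,
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=load_dataset("imdb", split="train[:1%]"),
    dataset_text_field="text",
    tokenizer=tokenizer,
    peft_config=peft_config,
)
trainer.train()  # fails in backward with the CheckpointError below
```

Running this under FSDP produces the following error: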
```
Traceback (most recent call last):
File "/ws/alpha_llms/sft/sft.py", line 211, in <module>
main()
File "/ws/alpha_llms/sft/sft.py", line 204, in main
trainer.train()
File "/root/miniconda3/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 440, in train
output = super().train(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 1885, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 3250, in training_step
self.accelerator.backward(loss)
File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 2125, in backward
loss.backward(**kwargs)
File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/root/miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/root/miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1126, in unpack_hook
frame.check_recomputed_tensors_match(gid)
File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 901, in check_recomputed_tensors_match
raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 14:
saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 17:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 19:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 29:
saved metadata: {'shape': torch.Size([512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 32:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 34:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 44:
saved metadata: {'shape': torch.Size([512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 47:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 49:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 71:
saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 74:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 76:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 91:
saved metadata: {'shape': torch.Size([18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 94:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 96:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 107:
saved metadata: {'shape': torch.Size([18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 110:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 112:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 124:
saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 127:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 129:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
```
### Expected behavior
I expected to be able to run SFT fine-tuning with TRL + PEFT + DoRA.
The suggestion in this comment helped:
https://github.com/Lightning-AI/pytorch-lightning/issues/19267#issuecomment-1888475338
Setting `use_reentrant=True` gets training running, though I haven't found the root cause yet.
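Concretely, the workaround amounts to passing `gradient_checkpointing_kwargs` through the training arguments (a sketch; only the checkpointing-related arguments matter here). The metadata check that raises the `CheckpointError` only exists in the non-reentrant checkpointing path, so forcing the reentrant implementation skips the check rather than fixing the underlying dtype mismatch:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    bf16=True,
    gradient_checkpointing=True,
    # check_recomputed_tensors_match() is only run by the non-reentrant
    # implementation; use_reentrant=True bypasses the check.
    gradient_checkpointing_kwargs={"use_reentrant": True},
)
```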
I hit this issue as well; I think it may be caused by wrapping the entire model for activation checkpointing, which is not recommended.
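If that is indeed the cause, applying activation checkpointing per decoder layer rather than to the whole model should be the cleaner fix. A minimal sketch using the torch FSDP checkpoint utilities, assuming a Qwen2-style model (the layer class is my assumption):

```python
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

# Checkpoint each transformer block individually so the recomputed region
# lines up with the FSDP sharding unit instead of spanning the whole model.
non_reentrant_wrapper = partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)
apply_activation_checkpointing(
    model,  # the FSDP-wrapped model from the trainer
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda module: isinstance(module, Qwen2DecoderLayer),
)
```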