
Error using SFTTrainer with Dora & FSDP

Open qZhang88 opened this issue 1 year ago • 2 comments

System Info

peft 0.11.2.dev0, transformers 4.41.2, Python 3.12.2, trl 0.9.4

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder
  • [ ] My own task or dataset (give details below)

Reproduction

Use TRL's SFTTrainer with the PEFT config modified to accept use_dora.
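
Roughly, the setup looks like the sketch below (the model id, dataset, and hyperparameters are placeholders, not the exact values from my script; the relevant part is use_dora=True together with gradient checkpointing under FSDP):

```python
# Minimal sketch of the failing setup (placeholders, not the exact script).
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    use_dora=True,  # the error only appears with DoRA enabled
)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    bf16=True,
    gradient_checkpointing=True,  # activation checkpointing is involved in the error
)

trainer = SFTTrainer(
    model="Qwen/Qwen2-7B",  # placeholder model id
    args=args,
    train_dataset=load_dataset("timdettmers/openassistant-guanaco", split="train"),
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=1024,
)
trainer.train()  # fails in backward with the CheckpointError below when launched with FSDP
```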

Traceback (most recent call last):
  File "/ws/alpha_llms/sft/sft.py", line 211, in <module>
    main()
  File "/ws/alpha_llms/sft/sft.py", line 204, in main
    trainer.train()
  File "/root/miniconda3/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 440, in train
    output = super().train(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 3250, in training_step
    self.accelerator.backward(loss)
  File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 2125, in backward
    loss.backward(**kwargs)
  File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/root/miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/root/miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1126, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 901, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 14:
saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 17:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 19:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 29:
saved metadata: {'shape': torch.Size([512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 32:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 34:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 44:
saved metadata: {'shape': torch.Size([512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 47:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 49:
saved metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 512]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 71:
saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 74:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 76:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 91:
saved metadata: {'shape': torch.Size([18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 94:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 96:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 107:
saved metadata: {'shape': torch.Size([18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 110:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 112:
saved metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 18944]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 124:
saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 127:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
tensor at position 129:
saved metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.float32, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([1, 3584]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}

Expected behavior

Expected to be able to run SFT fine-tuning with trl + peft + DoRA.

qZhang88 avatar Jul 06 '24 12:07 qZhang88

The suggestion in this comment helped:

https://github.com/Lightning-AI/pytorch-lightning/issues/19267#issuecomment-1888475338

After setting use_reentrant=True, training runs now. I haven't tracked down the root cause yet.
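
For reference, a sketch of the workaround (the surrounding TrainingArguments are placeholders; the relevant part is gradient_checkpointing_kwargs):

```python
# Workaround sketch: force the reentrant checkpointing implementation.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
)

# Equivalently, when enabling checkpointing on the model directly:
# model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
```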

qZhang88 avatar Jul 06 '24 13:07 qZhang88

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] avatar Aug 05 '24 15:08 github-actions[bot]

I had this issue as well, and I think it might be caused by wrapping the entire model for activation checkpointing, which is not recommended.

lucienwalewski avatar Aug 16 '24 08:08 lucienwalewski
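
For what it's worth, a sketch of what per-layer activation checkpointing could look like instead of wrapping the whole model (this is an assumption, not something confirmed in this thread; the decoder-layer class below is model-specific and would need to match your architecture):

```python
# Sketch: apply activation checkpointing only to transformer decoder layers
# of an FSDP-wrapped model, rather than to the whole model.
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer  # model-specific


def checkpoint_decoder_layers_only(model):
    """Wrap only the decoder layers for activation checkpointing."""
    # The checkpoint implementation choice here is illustrative, not a recommendation.
    wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        check_fn=lambda module: isinstance(module, Qwen2DecoderLayer),
    )
```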