
[Feature Request] Add state saving/loading support for SequentialSampler

Open zpqiu opened this issue 10 months ago • 3 comments

Description

Currently, PyTorch's SequentialSampler does not support saving and restoring its state during checkpoint operations. This makes it difficult to resume training from a checkpoint when using SequentialSampler, as the sampler will always restart from the beginning of the dataset.

Current Behavior

  • SequentialSampler has no mechanism to track or restore its current position
  • When loading from a checkpoint, the sampler always starts from index 0
  • No built-in methods for state dict serialization (state_dict() and load_state_dict())

Desired Behavior

  • Add state tracking to SequentialSampler
  • Support saving/loading the current index through state dict operations
  • Allow resuming iteration from the last checkpoint position (a rough sketch of such an API follows this list)
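
A rough sketch of what such an API could look like (this is a hypothetical illustration, not an existing PyTorch class; the name StatefulSequentialSampler and the _index field are made up here):

from torch.utils.data import Sampler

class StatefulSequentialSampler(Sampler):
    """Hypothetical sketch of a SequentialSampler with checkpointable state."""

    def __init__(self, data_source):
        self.data_source = data_source
        self._index = 0  # next dataset position to yield

    def __iter__(self):
        while self._index < len(self.data_source):
            yield self._index
            self._index += 1
        self._index = 0  # reset so the next epoch starts from the beginning

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        return {'index': self._index}

    def load_state_dict(self, state_dict):
        self._index = state_dict['index']

Note that with num_workers > 0 or prefetching, _index counts the indices already handed to the DataLoader rather than the batches actually consumed by the training loop, which is part of why exact mid-epoch resumption is tricky.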

Example Use Case

import torch
from torch.utils.data import DataLoader, SequentialSampler

# Training loop
dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset))
# ... training ...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'sampler': dataloader.sampler.state_dict()  # Currently not possible
}
torch.save(checkpoint, 'checkpoint.pt')

# Resume training
checkpoint = torch.load('checkpoint.pt')
dataloader.sampler.load_state_dict(checkpoint['sampler'])  # Currently not possible

Current Workarounds

Currently, users need to implement workarounds such as:

  1. Creating custom sampler classes
  2. Manually tracking indices
  3. Reordering the underlying dataset
  4. Manually skipping samples after loading (see the sketch after this list)
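
For example, workaround 4 can be approximated by fast-forwarding the dataloader iterator after restoring the checkpoint. A minimal sketch, assuming the training loop itself tracks and checkpoints a batches_consumed counter (that key is not saved automatically by anything) and reusing dataloader, model, and optimizer from the example above:

import itertools
import torch

checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
batches_consumed = checkpoint['batches_consumed']  # assumed to be saved by the training loop

# Fast-forward past the batches that were already seen before the checkpoint.
# Note: the skipped batches are still loaded, so this wastes data-loading time.
data_iter = itertools.islice(iter(dataloader), batches_consumed, None)

for batch in data_iter:
    ...  # continue training from the checkpointed position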

zpqiu · Feb 24 '25 03:02

How about using StatefulDataLoader instead of DataLoader? StatefulDataLoader provides state_dict and load_state_dict methods that may support resuming the iterator position from a mid-epoch checkpoint.
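
For example, the swap could look roughly like this (a sketch, assuming the torchdata package is installed and reusing dataset, model, and optimizer from the example above):

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

# Drop-in replacement for torch.utils.data.DataLoader that exposes
# state_dict() / load_state_dict(), covering the iterator position.
dataloader = StatefulDataLoader(dataset, batch_size=8, num_workers=2)

# Save mid-epoch progress together with the rest of the checkpoint.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'dataloader': dataloader.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pt')

# Resume: restore the dataloader state before iterating again.
checkpoint = torch.load('checkpoint.pt')
dataloader.load_state_dict(checkpoint['dataloader'])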

xffxff · Feb 24 '25 04:02

How about using StatefulDataLoader instead of DataLoader? StatefulDataLoader provides state_dict and load_state_dict methods that may support resuming the iterator position from a mid-epoch checkpoint.

Thank you. Will verl be modified to use StatefulDataLoader by default? Or do you recommend inheriting from Trainer and modifying it ourselves?

zpqiu · Feb 24 '25 07:02

I think it's reasonable. Please submit a PR. Thanks.

vermouth1992 · Feb 24 '25 08:02