Async Checkpointing with DCP and Stateful Dataloader

Open: conceptofmind opened this issue 5 months ago • 2 comments

Hello,

I was wondering what the recommended way is to use async checkpointing (DCP) with the StatefulDataLoader.

Does the following look correct?

from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DistributedSampler
import torch.distributed.checkpoint as dcp

class AsyncCheckpointer(Stateful):
    def __init__(self, model, optimizer, dataloader):
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader  # the StatefulDataLoader itself, not its state dict

    def state_dict(self):
        # get_state_dict() takes only the model and optimizer(s); it has no dataloader argument
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            # StatefulDataLoader provides its own state_dict()
            "dataloader": self.dataloader.state_dict(),
        }

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        # ...and its own load_state_dict()
        self.dataloader.load_state_dict(state_dict["dataloader"])

...
sampler = DistributedSampler(
    train_dataset,  # hypothetical name for the dataset built in the elided setup
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

trainloader = StatefulDataLoader(
    train_dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=2,
    collate_fn=data_collator,
)
...

app = AsyncCheckpointer(model, optimizer, trainloader)

# To resume, let DCP restore everything: dcp.load() calls app.load_state_dict(),
# which restores the dataloader as well. resume_checkpoint_id is a hypothetical
# variable holding the path of an earlier f"{CHECKPOINT_DIR}_step{step}" checkpoint.
if resume_checkpoint_id is not None:
    dcp.load({"app": app}, checkpoint_id=resume_checkpoint_id)

checkpoint_future = None
for step, batch in enumerate(trainloader):
    ...
    # Wait for the previous async save to finish before starting a new one
    if checkpoint_future is not None:
        checkpoint_future.result()

    # async_save calls app.state_dict(), including trainloader.state_dict(), before it returns
    checkpoint_future = dcp.async_save({"app": app}, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

It is unclear to me from the documentation how these two should be combined.

Thank you,

Enrico

Posted by conceptofmind on Aug 06 '25 05:08

Thanks for raising this. We don't have a recommended way of using the StatefulDataLoader with DCP, since the two are fairly independent, but maybe we should.

@pradeepfn, as the DCP developer, could you take a look at this?

Posted by divyanshk on Aug 06 '25 17:08

Hello,

Thank you for the response.

Would you recommend decoupling the StatefulDataLoader from async checkpointing, then, and saving a separate state dict for each?

Thank you,

Enrico

Posted by conceptofmind on Aug 06 '25 17:08