Async Checkpointing with DCP and Stateful Dataloader
Hello,
What is the recommended way to use async checkpointing (DCP) together with the StatefulDataLoader?
Does the following seem correct?
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
import torch.distributed.checkpoint as dcp
class AsyncCheckpointer(Stateful):
    def __init__(self, model, optimizer, dataloader):
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader

    def state_dict(self):
        model_state_dict, optimizer_state_dict, dataloader_state_dict = get_state_dict(
            self.model, self.optimizer, self.dataloader
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "dataloader": dataloader_state_dict
        }

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            self.dataloader,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            dataloader_state_dict=state_dict["dataloader"]
        )
...
sampler = DistributedSampler(
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)
trainloader = StatefulDataLoader(
    batch_size=64,
    sampler=sampler,
    num_workers=2,
    collate_fn=data_collator
)
...
checkpoint_future = None
trainloader.load_state_dict(state_dict)
for step, batch in enumerate(trainloader):
    ...
    if checkpoint_future is not None:
        checkpoint_future.result()
    dataloader_state_dict = trainloader.state_dict()
    state_dict = { "app": AsyncCheckpointer(model, optimizer, dataloader_state_dict) }
    checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")
It is unclear to me from the documentation how these two should be combined.
Thank you,
Enrico
Thanks for raising this. We don't have a recommended way of using DCP with the StatefulDataLoader, since the two are fairly independent, but maybe we should.
@pradeepfn as the DCP dev, can you take a look at this?
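In the meantime, one pattern that might work is to give the dataloader its own small wrapper and pass it to DCP next to the model/optimizer wrapper. `get_state_dict` and `set_state_dict` take a model and optimizer(s) only, so the dataloader has to go through its own `state_dict()` / `load_state_dict()`. The sketch below is torch-free so it can be run on its own: `DataLoaderState`, the `rank_{rank}` key format, and `FakeLoader` are hypothetical names, and storing the state as pickled bytes under a per-rank key is an assumption (each rank's dataloader state differs, so entries are kept from colliding). It has not been verified end to end with DCP.

```python
import pickle
from typing import Any, Dict


class DataLoaderState:
    """Hypothetical wrapper exposing a dataloader's state under a per-rank key.

    Works with any object that has state_dict() / load_state_dict(),
    such as torchdata's StatefulDataLoader.
    """

    def __init__(self, dataloader: Any, rank: int) -> None:
        # Hold the live dataloader (not a state dict) so it can be restored later.
        self.dataloader = dataloader
        self.key = f"rank_{rank}"  # assumed key format, one entry per rank

    def state_dict(self) -> Dict[str, bytes]:
        # Pickle the (possibly nested) state into a single opaque blob.
        return {self.key: pickle.dumps(self.dataloader.state_dict())}

    def load_state_dict(self, state_dict: Dict[str, bytes]) -> None:
        blob = state_dict.get(self.key)
        if blob is None:
            # No entry for this rank (e.g. world size changed): start fresh.
            return
        self.dataloader.load_state_dict(pickle.loads(blob))


class FakeLoader:
    """Stand-in for StatefulDataLoader so the sketch runs without torch."""

    def __init__(self) -> None:
        self.batches_yielded = 0

    def state_dict(self) -> Dict[str, int]:
        return {"batches_yielded": self.batches_yielded}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.batches_yielded = state["batches_yielded"]


# Round trip: capture state from one loader and restore it into another.
src = FakeLoader()
src.batches_yielded = 7
saved = DataLoaderState(src, rank=0).state_dict()

dst = FakeLoader()
DataLoaderState(dst, rank=0).load_state_dict(saved)
print(dst.batches_yielded)  # 7
```

In the training loop this would presumably be used as `{"app": AsyncCheckpointer(model, optimizer), "dataloader": DataLoaderState(trainloader, rank)}` passed to `dcp.async_save`, with `AsyncCheckpointer` reduced to the model and optimizer only.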
Hello,
Thank you for the response.
Would you recommend decoupling the StatefulDataLoader from async checkpointing, then, and saving a separate state dict for each?
Thank you,
Enrico