torchtitan
dcp.load fails on checkpoints prior to AdamW refactor
I recently upgraded to a nightly PyTorch build (2.7.0.dev20250221+rocm6.3) and noticed that DCP checkpoints saved with my previous build can no longer be loaded. The issue seems to be the same as the one discussed here: https://github.com/pytorch/pytorch/issues/146157
As I understand it, the currently recommended fix seems to be to pass a DefaultLoadPlanner with allow_partial_load=True to dcp.load. Would there be any objections to updating this:
    dcp.load(
        states_to_load,
        checkpoint_id=self._create_checkpoint_id(step),
    )
to something like:
    dcp.load(
        states_to_load,
        planner=dcp.DefaultLoadPlanner(allow_partial_load=True),
        checkpoint_id=self._create_checkpoint_id(step),
    )
for backwards compatibility.
Yes, we probably have to work around the BC issue, since it is caused by the AdamW change.