
dcp.load fails on checkpoints prior to AdamW refactor

Open eminorhan opened this issue 10 months ago • 1 comment

I recently upgraded to a nightly pytorch build (2.7.0.dev20250221+rocm6.3) and noticed that the dcp checkpoints saved with my prior build could no longer be loaded successfully. The issue seems to be the same as the one discussed here: https://github.com/pytorch/pytorch/issues/146157

As I understand it, the currently recommended fix is to pass a DefaultLoadPlanner with allow_partial_load=True to dcp.load. So I was wondering if there would be any objections to updating this:

dcp.load(
    states_to_load,
    checkpoint_id=self._create_checkpoint_id(step),
)

to something like:

dcp.load(
    states_to_load,
    planner=dcp.DefaultLoadPlanner(allow_partial_load=True),
    checkpoint_id=self._create_checkpoint_id(step),
)

for backwards compatibility.
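
To make the intended behavior concrete, here is a toy, stdlib-only sketch of strict versus partial loading. It is not the DCP implementation: `load_state` and the key names are hypothetical, and plain dicts stand in for the real state dict and checkpoint metadata. The assumption is only that with allow_partial_load=True, keys requested by the current state dict but absent from the checkpoint are skipped instead of raising.

```python
def load_state(target, checkpoint, allow_partial_load=False):
    """Copy checkpoint values into `target` in place (toy model, not DCP).

    Returns the list of keys in `target` that the checkpoint did not contain.
    """
    missing = [key for key in target if key not in checkpoint]
    if missing and not allow_partial_load:
        # Strict mode: refuse to load if any requested key is absent.
        raise RuntimeError(f"Missing key in checkpoint state_dict: {missing[0]}.")
    for key in target:
        if key in checkpoint:
            target[key] = checkpoint[key]
    return missing


# Hypothetical keys: the old checkpoint predates an optimizer option
# that the newer build expects to find.
old_checkpoint = {"model.weight": 1.5, "optim.lr": 0.001}
states_to_load = {"model.weight": 0.0, "optim.lr": 0.1, "optim.new_option": True}

skipped = load_state(states_to_load, old_checkpoint, allow_partial_load=True)
print(skipped)         # keys left at their current (default) values
print(states_to_load)  # everything else restored from the old checkpoint
```

The trade-off is that partial loading also hides keys that are missing by mistake, so it may be worth logging the skipped keys rather than ignoring them.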

eminorhan, Feb 25 '25 07:02

Yes, we probably have to work around the BC issue, as it is caused by the AdamW change.

fegin, Feb 25 '25 07:02