jonb377

29 comments by jonb377

Also for some additional context, is this with a single host or are you using multiple hosts?

Do you know what epoch's checkpoint causes the issue? It sounds like pausing and resuming works, so when you restore from a high epoch there is no issue. I would...

> In your code, `model = model.to(device)` is called after `model.load_state_dict(checkpoint['model_state_dict'])`. In my training code, I do not pass the model to device. Can you check? Is this erroneous?

Oh...

Thanks @mfatih7, looking forward to the results! One other thing that could help with debugging is disabling async checkpointing by using `chkpt_mgr.save` instead of `chkpt_mgr.save_async`.
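For illustration, the switch between the two save paths could be put behind a flag. This is only a sketch: it assumes `chkpt_mgr` exposes `save(step, state_dict)` and `save_async(step, state_dict)`, and the helper `save_checkpoint`, the `use_async` flag, and the `_RecordingManager` stand-in are hypothetical names used so the snippet runs without torch_xla installed.

```python
def save_checkpoint(chkpt_mgr, step, state_dict, use_async=True):
    """Save through the async path, or the blocking path when debugging.

    Hypothetical helper: `chkpt_mgr` is assumed to provide
    `save(step, state_dict)` and `save_async(step, state_dict)`.
    """
    if use_async:
        return chkpt_mgr.save_async(step, state_dict)
    # Blocking save: any failure is raised here, in the training thread.
    return chkpt_mgr.save(step, state_dict)


class _RecordingManager:
    """Stand-in for a real checkpoint manager; it only records calls."""

    def __init__(self):
        self.calls = []

    def save(self, step, state_dict):
        self.calls.append(("save", step))
        return True

    def save_async(self, step, state_dict):
        self.calls.append(("save_async", step))
        return True


mgr = _RecordingManager()
save_checkpoint(mgr, 10, {"w": 1.0}, use_async=False)
print(mgr.calls)
```

With the flag off, every checkpoint goes through the blocking call, which makes it easier to tell whether a failure comes from the save itself or from the background work.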

Hmm that's interesting... With `try/except` the error would be suppressed, but do you see the `print`ed message anywhere in the output? Also, was this run with `save` or `save_async`?

Ah good to hear! So it's not the initial checkpoint which causes the failure then. That points to the actual train loop as the culprit. A few more questions: -...

Hey @mfatih7, it sounds like your use case is to group the devices in pairs and train data parallel across those groups, correct? This is achievable with SPMD. As an...
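To make the pairing idea concrete, here is a rough sketch of how devices might be laid out as a two-axis mesh, with one axis indexing the pairs and the other indexing the two devices inside each pair. It is plain Python and does not call torch_xla; the helper `pair_mesh` and the axis names in the comment are hypothetical, and the device count of 8 is just an example.

```python
def pair_mesh(num_devices, group_size=2):
    """Compute a (groups, group_size) mesh layout for `num_devices` devices.

    Hypothetical helper for illustration; it returns the flat device ids,
    the mesh shape, and the explicit groups.
    """
    if num_devices % group_size != 0:
        raise ValueError("num_devices must be divisible by group_size")
    device_ids = list(range(num_devices))
    mesh_shape = (num_devices // group_size, group_size)
    groups = [
        device_ids[i:i + group_size]
        for i in range(0, num_devices, group_size)
    ]
    return device_ids, mesh_shape, groups


device_ids, mesh_shape, groups = pair_mesh(8)
print(mesh_shape)
print(groups)

# Assumption: with torch_xla's SPMD API available, a layout like this could be
# handed to a mesh constructor, e.g. xs.Mesh(device_ids, mesh_shape,
# ("data", "model")), sharding the batch along the first axis and the model
# along the second. The axis names here are illustrative only.
```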

It would also be worth checking out the FSDPv2 wrapper if you just want to train a bigger model using all devices: https://github.com/pytorch/xla/issues/6379

These warnings come from the upstream distributed checkpointing library and are OK to ignore for now. The deprecations will be addressed before the 2.3 release (the `save_state_dict` and `load_state_dict` functions...