[zero_to_fp32] fix shared param recovery
Fixes: https://github.com/microsoft/DeepSpeed/pull/3033
The algorithm for detecting shared params added in https://github.com/microsoft/DeepSpeed/pull/3033 doesn't work: in the checkpoint all tensors are size-0 placeholders whose data_ptr() is always 0, so the equality comparison is bogus. This leads to an invalid reconsolidated state_dict in which a shared param ends up sharing with a random other param. It may work accidentally if the shared param happens to come first in the list.
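For illustration, here is a minimal sketch of that failure mode (param names are made up). Since every size-0 placeholder reports the same data_ptr(), a data_ptr()-based grouping lumps all params together rather than only the truly tied ones:

```python
import torch

# ZeRO-3 checkpoints store size-0 placeholder tensors instead of real param data.
# Every such placeholder reports data_ptr() == 0, so grouping params by
# data_ptr() lumps *all* of them together instead of only the truly tied ones.
state_dict = {
    "embed.weight": torch.empty(0),    # placeholder
    "lm_head.weight": torch.empty(0),  # placeholder (actually tied to embed.weight)
    "layer.weight": torch.empty(0),    # placeholder (not tied to anything)
}

groups = {}
for name, tensor in state_dict.items():
    groups.setdefault(tensor.data_ptr(), []).append(name)

print(groups)
# -> {0: ['embed.weight', 'lm_head.weight', 'layer.weight']}
#    i.e. every param looks "shared" with whichever one happened to come first.
```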
This PR pre-computes the shared params using the same logic as _zero3_consolidated_16bit_state_dict, where we rely on ds_id to identify shared params:
https://github.com/microsoft/DeepSpeed/blob/a094c9763de8d42107cbffd0bb9abb8056aa3c60/deepspeed/runtime/engine.py#L3225-L3228
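A hedged sketch of that ds_id-based logic (adapted, not the verbatim engine.py code; the helper name and the pair layout are assumptions):

```python
def collect_shared_param_pairs(module):
    """Sketch only (not the verbatim engine.py code): shared/tied ZeRO-3 params
    carry the same ``param.ds_id``, which stays meaningful even when data_ptr()
    does not. The first name seen for a given ds_id becomes the "owner", and
    every later name with the same ds_id is recorded as tied to it."""
    owner_by_ds_id = {}  # ds_id -> first (owner) param name
    shared_pairs = []    # [tied_name, owner_name] pairs to persist in the checkpoint

    for name, param in module.named_parameters():
        if param.ds_id in owner_by_ds_id:
            shared_pairs.append([name, owner_by_ds_id[param.ds_id]])
        else:
            owner_by_ds_id[param.ds_id] = name
    return shared_pairs
```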
This data is then stored in the checkpoint and is easy to use in zero_to_fp32.py.
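On the zero_to_fp32.py side, consuming that mapping is then trivial; a hedged sketch (the function name and the exact checkpoint key/pair layout are assumptions, not necessarily the format this PR stores):

```python
def restore_shared_params(state_dict, shared_pairs):
    """Sketch: re-tie shared params in the reconsolidated fp32 state_dict using
    the [tied_name, owner_name] pairs recovered from the checkpoint (the exact
    checkpoint key and pair layout are assumptions here, not the PR's format)."""
    for tied_name, owner_name in shared_pairs:
        # point the tied entry at the owner's reconstructed fp32 tensor
        state_dict[tied_name] = state_dict[owner_name]
    return state_dict
```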
Please please please always add tests when adding new features.
POSSIBLE TODO:
- Could there be shared buffers? If so, we probably need to extend this to handle buffers as well. I skipped it since we aren't handling shared buffers in _zero3_consolidated_16bit_state_dict either, so this is symmetrical.
@tjruwase
@ShijieZZZZ and @mayank31398 - please check that this approach still works for you. Without tests there is no way to know. Thank you!
@tjruwase, thinking aloud here: what's the point of saving tensor placeholders if they are stripped of their ds metadata? This is again a source of confusion, and here it was the cause of the bug. If something isn't usable, let's not put it into the checkpoint.
Thanks @stas00, I'll test this out sometime today :)