Can we support outputting checkpoints directly in .pt format?
Today we need to do an extra conversion step, per the checkpoint doc: https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md
python -m torch.distributed.checkpoint.format_utils dcp_to_torch outputs/checkpoint/step-100 /tmp/checkpoint.pt
I think we should instead provide an option for users to specify the output format for their checkpoints, and have torchtitan call this conversion function for them as part of writing the checkpoint.
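For illustration, a minimal sketch of what that could look like, assuming a hypothetical `export_format` option (not an existing torchtitan flag); `dcp_to_torch_save` is the existing PyTorch utility behind the CLI command above:

```python
# Sketch only: `export_format` is a hypothetical option, not an existing flag.
# `dcp_to_torch_save` is the utility that the format_utils CLI wraps.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save


def maybe_export_checkpoint(dcp_dir: str, export_format: str = "dcp") -> None:
    """After the DCP save finishes, optionally emit a single .pt file."""
    if export_format == "pt":
        # e.g. outputs/checkpoint/step-100 -> outputs/checkpoint/step-100.pt
        dcp_to_torch_save(dcp_dir, f"{dcp_dir}.pt")
```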
Bonus: this conversion step actually fails today if FP8 training was used. I had to manually add the following line to the dcp_to_torch function as a hack to get it to work:
torch.serialization.add_safe_globals([torchao.float8.fsdp_utils.WeightWithDynamicFloat8CastTensor])
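For reference, a sketch of the same workaround done from a small wrapper script instead of patching PyTorch (paths are placeholders):

```python
# Workaround sketch: register the torchao subclass as a safe global in our own
# script, then call the conversion utility, rather than editing dcp_to_torch.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor])
dcp_to_torch_save("outputs/checkpoint/step-100", "/tmp/checkpoint.pt")
```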
It would be great if we could either implicitly add the safe globals when torchtitan outputs the checkpoint, or simply remove WeightWithDynamicFloat8CastTensor from the BC surface.
cc @vkuzo @fegin @wz337
It is reasonable to remove the FP8 subclass from checkpointing. I'll submit a PR for this. I may need some help from the AO team to discuss how to remove the FP8 subclass. cc @vkuzo @danielvegamyhre
Mind if I ask why you would like .pt as the output format? E.g., is it because some downstream workload has to consume the .pt format but not DCP?
Hi @tianyu-l, yes exactly. Our use case is interop with torchtune, which accepts .pt or .safetensors. Actually, between the two, .safetensors would be more useful, but .pt is also fine.
I believe we can support both formats. The issue is how to remove the FP8 tensor.
Sorry for the late reply, catching up after my leave.

> how to remove the FP8 tensor

Is there a hook of some sort that torchtitan calls when saving a state dict? The logic could go there.
@vkuzo I can draft a PR to enable saving the state_dict to .pt. We don't need a hook for that; we can just always convert the FP8 tensors to the dtype users prefer.
@vkuzo @danielvegamyhre @andrewor14 Please see the TODO in the code of https://github.com/pytorch/torchtitan/pull/1219. We just need to convert the FP8 tensors to regular tensors in _export_weights().
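For context, a rough sketch of the kind of conversion being discussed; this is not the PR's code, and the `_tensor` attribute and bf16 default are assumptions about the wrapper subclass:

```python
import torch


def _fp8_subclass_to_plain(state_dict: dict, dtype: torch.dtype = torch.bfloat16) -> dict:
    """Replace FP8 wrapper subclasses with plain tensors before exporting.

    Assumes the wrapper (e.g. WeightWithDynamicFloat8CastTensor) keeps its
    high-precision data in a `_tensor` attribute; adjust if the subclass differs.
    """
    out = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor) and type(value) is not torch.Tensor:
            # Best effort: unwrap the subclass, then cast to the requested dtype.
            inner = getattr(value, "_tensor", value)
            value = inner.detach().to(dtype)
        out[name] = value
    return out
```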