
Can we support outputting checkpoints directly in .pt format?

andrewor14 opened this issue 7 months ago · 8 comments

Today we need to run an extra conversion step, per this doc: https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md

```
python -m torch.distributed.checkpoint.format_utils dcp_to_torch outputs/checkpoint/step-100 /tmp/checkpoint.pt
```

I think we should instead provide an option for users to specify the output format of their checkpoints, and have torchtitan call this function for them as part of saving the checkpoint.
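
For illustration, a minimal sketch of what that could look like, assuming a hypothetical config option and call site (neither exists in torchtitan today; `dcp_to_torch_save` is the existing helper behind the CLI above):

```python
# Hypothetical sketch: torchtitan converts the checkpoint for the user right
# after saving, driven by a user-facing format option. The option name and
# the function below are assumptions, not existing torchtitan API.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

def maybe_export_checkpoint(checkpoint_dir: str, export_format: str) -> None:
    if export_format == "pt":
        # Same conversion as the CLI command above, done automatically.
        dcp_to_torch_save(checkpoint_dir, f"{checkpoint_dir}/checkpoint.pt")
```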


Bonus: this conversion step actually fails today when FP8 training was used. I had to manually add the following line to the dcp_to_torch function as a hack to get it to work:

```python
import torchao.float8.fsdp_utils  # needed for the subclass reference below

torch.serialization.add_safe_globals([torchao.float8.fsdp_utils.WeightWithDynamicFloat8CastTensor])
```

It would be great if we could either implicitly add the safe globals when torchtitan outputs the checkpoint, or simply remove WeightWithDynamicFloat8CastTensor from the BC surface.
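
For reference, the same workaround can be applied without patching PyTorch internals, since `add_safe_globals` registers the class process-wide. A sketch, assuming torchao is installed and the checkpoint paths from above:

```python
# Register the FP8 tensor subclass as a safe global before converting,
# instead of editing dcp_to_torch. The registration is process-global,
# so the conversion picks it up when it calls torch.load internally.
import torch.serialization
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor])
dcp_to_torch_save("outputs/checkpoint/step-100", "/tmp/checkpoint.pt")
```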

andrewor14 avatar May 09 '25 16:05 andrewor14

cc @vkuzo @fegin @wz337

andrewor14 avatar May 09 '25 16:05 andrewor14

It is reasonable to remove the FP8 subclass from checkpointing. I'll submit a PR for this. I may need some help from the AO team to discuss how to remove the FP8 subclass. cc @vkuzo @danielvegamyhre

fegin avatar May 09 '25 16:05 fegin

Mind if I ask why you would like .pt as the output format? E.g., is it because some downstream workload has to consume the .pt format but not DCP?

tianyu-l avatar May 13 '25 05:05 tianyu-l

Hi @tianyu-l, yes exactly. Our use case is interop with torchtune, which accepts .pt or .safetensors. Actually, between these two, .safetensors would be more useful, but .pt is also fine.
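
For context, once the checkpoint is a plain state dict of regular tensors, writing .safetensors is a one-liner with the safetensors package. A sketch with placeholder paths; depending on how the checkpoint was saved, you may need to index into a nested dict first (e.g. `state_dict["model"]`):

```python
# Convert a .pt state dict to .safetensors. Assumes the dict contains only
# plain tensors (safetensors rejects tensor subclasses and shared storage).
import torch
from safetensors.torch import save_file

state_dict = torch.load("/tmp/checkpoint.pt", map_location="cpu")
save_file(state_dict, "/tmp/checkpoint.safetensors")
```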

andrewor14 avatar May 13 '25 16:05 andrewor14

I believe we can support both formats. The open question is how to remove the FP8Tensor.

fegin avatar May 13 '25 17:05 fegin

Sorry for the late reply, catching up after my leave.

> how do we remove the FP8Tensor

Is there a hook of some sort that torchtitan calls when saving a state dict? The logic could go there.

vkuzo avatar May 23 '25 13:05 vkuzo

@vkuzo I can draft a PR to enable saving the state_dict to .pt. We don't need a hook for that; we can just always convert the FP8 tensors to the dtype users prefer.
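
A minimal sketch of that conversion, with one loud caveat: the `.to(dtype)` unwrap is an assumption, and the exact call needed to materialize a plain tensor from WeightWithDynamicFloat8CastTensor may differ.

```python
import torch

def dequantize_state_dict(state_dict: dict, dtype: torch.dtype = torch.bfloat16) -> dict:
    # Replace any tensor subclass (e.g. the FP8 wrapper) with a plain tensor
    # in the user's preferred dtype before torch.save. The .to(dtype) unwrap
    # is a hedge; the real conversion may need a subclass-specific call.
    out = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor) and type(value) is not torch.Tensor:
            value = value.to(dtype)
        out[name] = value
    return out
```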

fegin avatar May 23 '25 16:05 fegin

@vkuzo @danielvegamyhre @andrewor14 Please see the TODO in the code of https://github.com/pytorch/torchtitan/pull/1219. We just need to convert the FP8 tensors to regular tensors in _export_weights().

fegin avatar May 23 '25 18:05 fegin