torchtitan
[WIP] Implement the feature to save unsharded weights at the last step
Summary: Several users have been asking for this feature: https://github.com/pytorch/torchtitan/issues/1177
TODO: Remove fp8 subclass tensors
TODO: Support HF format
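As a rough illustration of the mechanics (not torchtitan's actual implementation, which goes through DCP), unsharding boils down to all-gathering each DTensor parameter into a full tensor and casting it to the export dtype. In the sketch below, `unshard_state_dict` is a hypothetical helper, and the direct `torch.save` path is illustrative only:

```python
# Illustrative sketch: gather DTensor shards into full tensors and cast
# to the export dtype. `unshard_state_dict` is a hypothetical helper,
# not part of torchtitan.
import torch
from torch.distributed.tensor import DTensor

def unshard_state_dict(state_dict, export_dtype=torch.bfloat16):
    full_sd = {}
    for name, value in state_dict.items():
        if isinstance(value, DTensor):
            # full_tensor() all-gathers the shards across the device mesh,
            # so every rank ends up with the complete tensor.
            value = value.full_tensor()
        full_sd[name] = value.to(export_dtype)
    return full_sd

# Only rank 0 needs to write the consolidated checkpoint:
# if torch.distributed.get_rank() == 0:
#     torch.save(unshard_state_dict(model.state_dict()), "model_weights.pt")
```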
Test Plan:
```shell
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 4 --parallelism.enable_async_tensor_parallel --checkpoint.model_weights_only --checkpoint.unshard_weights --checkpoint.export_dtype="bfloat16" --training.steps=10 --checkpoint.enable_checkpoint
```
> TODO: Support HF format

This would require an HF dependency in torchtitan core, right?
> This would require an HF dependency in torchtitan core, right?

Yes, unfortunately, that is the case. PyTorch also optionally depends on HF due to DCP. We can use the same logic: error out when users specify the HF format but the HF package is not installed.
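A minimal sketch of that guard, assuming `safetensors` stands in for the HF-side dependency; `save_hf_format` is a hypothetical name, not an existing torchtitan function:

```python
# Sketch of an optional-dependency guard; `save_hf_format` is hypothetical.
try:
    from safetensors.torch import save_file  # HF package, optional install
    _HAS_SAFETENSORS = True
except ImportError:
    _HAS_SAFETENSORS = False

def save_hf_format(state_dict, path):
    # Fail loudly only when the HF format is actually requested.
    if not _HAS_SAFETENSORS:
        raise ModuleNotFoundError(
            "HF format was requested but `safetensors` is not installed; "
            "run `pip install safetensors` to enable it."
        )
    save_file(state_dict, path)
```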