Calvin Pelletier
@joecummings ~~I'm guessing this is because the causal mask is created in `setup_caches()` [here](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/transformer.py#L171), so without calling this function we're attending to all tokens, resulting in garbage outputs. Maybe we...
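To illustrate the point (a plain-Python sketch, not torchtune's actual mask construction): a causal mask lets query position `i` attend only to key positions `j <= i`, so skipping its creation means every token also attends to future tokens.

```python
def causal_mask(seq_len):
    """Boolean causal attention mask: entry [i][j] is True when
    query position i may attend to key position j (i.e. j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

# Without this mask, every position also attends to future tokens,
# which breaks autoregressive generation and yields garbage outputs.
print(causal_mask(3))
```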
Hey @apachemycat, an option to save only the trainable weights for intermediate checkpoints is a great idea! We will add support for this soon. Regarding checkpointing every N steps, this...
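Until that lands, one way to get the same effect in a modified recipe is to filter the checkpoint down to trainable parameters yourself. A minimal PyTorch sketch (this is not torchtune's checkpointing API; `trainable_state_dict` is a hypothetical helper):

```python
import torch
import torch.nn as nn

def trainable_state_dict(model: nn.Module) -> dict:
    # Keep only parameters with requires_grad=True (e.g. adapter weights),
    # detached and moved to CPU so the checkpoint holds no GPU memory.
    return {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if param.requires_grad
    }

model = nn.Linear(4, 4)
model.bias.requires_grad = False  # pretend the bias is frozen
ckpt = trainable_state_dict(model)
# torch.save(ckpt, "trainable_only.pt")
```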
Hi @vgoklani , we don't currently support this, but you could modify a recipe to call [torchao.float8.convert_to_float8_training](https://github.com/pytorch/ao/tree/main/torchao/float8) on your model at the end of [this function](https://github.com/pytorch/torchtune/blob/aa8f365f91a69aa36aaea14cf6f03ccd45310bb6/recipes/full_finetune_single_device.py#L410). However, I recommend using...
We would definitely appreciate a PR if full-finetuning in FP8 works out well for you all!
Yay step-based checkpointing! Some thoughts: 1. I second Felipe's comment about dropping support for epoch-based checkpointing. Our code will be cleaner and simpler if our whole ecosystem of checkpointing/validating/logging/etc is...
Hi @zhangtemplar , you're changing the generic `convert_weights` function. Qwen2.5 already has a specific convert weights function [here](https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_convert_weights.py) which handles the biases of the linear projections. In our Qwen2.5 configs,...
It looks like the rank zero device is getting stuck somewhere in this helper function: https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpoint_client.py#L355 Can you add some logs in there to see what's going on?
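For quick debugging, a stdlib-only logging helper along these lines works inside distributed recipes (a sketch; it assumes `torchrun` has set the `RANK` environment variable for each worker):

```python
import os
import time

def rank_log(msg: str) -> str:
    # RANK is set per process by torchrun; default to 0 when run standalone.
    rank = os.environ.get("RANK", "0")
    line = f"[rank {rank}] {time.strftime('%H:%M:%S')} {msg}"
    # flush=True so the log appears even if the process hangs right after
    print(line, flush=True)
    return line
```

Sprinkling calls like `rank_log("entered checkpoint load")` before and after each step in the helper narrows down where rank zero stalls.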
Yes that is abnormally high, I'm digging into this
@joecummings should Llama4 Scout use more CPU memory than Llama3.3 70B? When full finetuning Scout (`tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full` on 8xA100s), CPU mem is constant at 500...
Here's some additional info: https://github.com/pytorch/torchtune/issues/2111#issuecomment-2519077960
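For spot-checking the CPU memory question above from inside a recipe, a minimal stdlib sketch (Unix-only; on Linux `ru_maxrss` is reported in KiB, on macOS in bytes, so the conversion below assumes Linux):

```python
import resource

def peak_cpu_rss_gib() -> float:
    # Peak resident set size of this process; KiB on Linux.
    kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return kib / (1024 ** 2)
```

Logging this per rank at a few points in the recipe would show whether the 500+ GiB is spread evenly or concentrated on rank zero.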