Calvin Pelletier

Results: 10 comments by Calvin Pelletier

@joecummings I'm guessing this is because the causal mask is created in `setup_caches()` [here](https://github.com/pytorch/torchtune/blob/main/torchtune/modules/transformer.py#L171), so without calling this function we're attending to all tokens, resulting in garbage outputs. Maybe we...
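For intuition, here is a minimal standalone PyTorch sketch of the failure mode (this is not torchtune's implementation; torchtune builds its mask inside `setup_caches()`): with the lower-triangular mask, each position only attends to earlier tokens, and without it, every position attends to the full sequence, including future tokens.

```python
import torch
import torch.nn.functional as F

seq_len, dim = 4, 8
q = torch.randn(1, seq_len, dim)
k = torch.randn(1, seq_len, dim)
v = torch.randn(1, seq_len, dim)

# Lower-triangular causal mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# With no mask, every position attends to all tokens, including future ones:
# the failure mode described above.
unmasked = F.scaled_dot_product_attention(q, k, v)
```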

Hey @apachemycat, an option to save only the trainable weights for intermediate checkpoints is a great idea! We will add support for this soon. Regarding checkpointing every N steps, this...
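In the meantime, a rough sketch of the workaround in plain PyTorch (the function name is illustrative, not a torchtune API): filter the state dict down to parameters with `requires_grad=True` before saving.

```python
import torch

def save_trainable_only(model: torch.nn.Module, path: str) -> None:
    # keep_vars=True returns the live Parameter objects, so requires_grad
    # is still visible; buffers report requires_grad=False and are skipped.
    trainable = {
        name: param.detach().cpu()
        for name, param in model.state_dict(keep_vars=True).items()
        if getattr(param, "requires_grad", False)
    }
    torch.save(trainable, path)
```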

Hi @vgoklani, we don't currently support this, but you could modify a recipe to call [torchao.float8.convert_to_float8_training](https://github.com/pytorch/ao/tree/main/torchao/float8) on your model at the end of [this function](https://github.com/pytorch/torchtune/blob/aa8f365f91a69aa36aaea14cf6f03ccd45310bb6/recipes/full_finetune_single_device.py#L410). However, I recommend using...
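A sketch of what that modification might look like, adapted from the torchao float8 README (the `module_filter_fn` shown is an assumption; which layers are safe to convert depends on the model, and fp8 matmuls need dimensions divisible by 16):

```python
import torch
from torchao.float8 import convert_to_float8_training

def convert_model_to_fp8(model: torch.nn.Module) -> torch.nn.Module:
    def module_filter_fn(module: torch.nn.Module, fqn: str) -> bool:
        # Illustrative: skip the final output projection, a common
        # precaution since its shape often isn't fp8-friendly.
        return fqn != "output"

    # Swaps eligible nn.Linear layers for float8 training variants in place.
    convert_to_float8_training(model, module_filter_fn=module_filter_fn)
    return model
```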

We would definitely appreciate a PR if full-finetuning in FP8 works out well for you all!

Yay step-based checkpointing! Some thoughts: 1. I second Felipe's comment about dropping support for epoch-based checkpointing. Our code will be cleaner and simpler if our whole ecosystem of checkpointing/validating/logging/etc. is...
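For illustration, the step-based cadence in its simplest form (hypothetical helper callables, not torchtune's recipe code); validation and logging can hang off the same global step:

```python
from typing import Callable, Iterable

def train(
    batches: Iterable,
    train_step: Callable,       # hypothetical: runs one optimizer step
    save_checkpoint: Callable,  # hypothetical: writes a checkpoint
    checkpoint_every_n_steps: int = 1000,
) -> None:
    # One flat loop keyed on the global step; epoch boundaries play no
    # role in the checkpointing cadence.
    for step, batch in enumerate(batches, start=1):
        train_step(batch)
        if step % checkpoint_every_n_steps == 0:
            save_checkpoint(step)
```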

Hi @zhangtemplar, you're changing the generic `convert_weights` function. Qwen2.5 already has a model-specific convert-weights function [here](https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_convert_weights.py) which handles the biases of the linear projections. In our Qwen2.5 configs,...
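To illustrate the distinction, here is the shape of the extra key mapping a model-specific converter carries; the entries below are illustrative, and the real mapping lives in the linked `_convert_weights.py`:

```python
# Illustrative fragment of an HF -> torchtune key map covering the
# q/k/v projection biases that a generic converter would miss.
_QWEN2_BIAS_KEYS = {
    "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias",
    "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias",
    "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias",
}
```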

It looks like the rank-zero process is getting stuck somewhere in this helper function: https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpoint_client.py#L355 Can you add some logs in there to see what's going on?
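Something along these lines (illustrative; any rank-tagged print works) dropped before and after the suspect calls would narrow down where rank zero hangs:

```python
import logging
import torch.distributed as dist

log = logging.getLogger(__name__)

def log_rank(msg: str) -> None:
    # Tag every message with the rank so a hang on rank 0 stands out.
    rank = dist.get_rank() if dist.is_initialized() else 0
    log.info("[rank %d] %s", rank, msg)
```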

Yes, that is abnormally high; I'm digging into this.

@joecummings should Llama 4 Scout use more CPU memory than Llama 3.3 70B? When full finetuning Scout (`tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full` on 8xA100s), CPU mem is constant at 500...
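For anyone reproducing this, a simple per-process probe of resident CPU memory (psutil here is just one option; any RSS readout works):

```python
import os
import psutil

def log_cpu_rss(tag: str) -> None:
    # Resident set size of the current process, in GiB.
    rss_gib = psutil.Process(os.getpid()).memory_info().rss / 2**30
    print(f"[{tag}] CPU RSS: {rss_gib:.1f} GiB")
```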

Here's some additional info: https://github.com/pytorch/torchtune/issues/2111#issuecomment-2519077960