Add validation for LinearCrossEntropyLoss with custom_sharded_layers
Summary
Fixes #2856 - DTensor/torch.Tensor mixed type error in Llama4 LoRA fine-tuning
Problem
When using LoRA fine-tuning with LinearCrossEntropyLoss and custom_sharded_layers, users encounter a tensor type mismatch error:
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
This happens because:
- LinearCrossEntropyLoss uses model.output for the final projection
- LoRA configs typically set custom_sharded_layers = ['tok_embeddings'] without including 'output'
- FSDP wraps only the layers listed in custom_sharded_layers as DTensors
- This creates a mismatch when computing the loss (DTensor hidden states × regular Tensor output weights)
Solution
Added validation that checks if LinearCrossEntropyLoss is used with custom_sharded_layers and ensures 'output' is included in the list. This provides a clear, actionable error message at setup time rather than a cryptic error during training.
Implementation
- Created a shared validation module, recipes/validation.py, to avoid code duplication (sketched below)
- Added validation to both the full_finetune_distributed.py and lora_finetune_distributed.py recipes
- Validation is called in _setup_model before FSDP wrapping occurs
- Added comprehensive unit tests covering various edge cases
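For concreteness, here is a minimal sketch of what such a validation helper could look like; the function name and signature are placeholders rather than the code actually added in the PR:

```python
# Sketch only: mirrors the behaviour described above, but the name and
# signature are illustrative, not the PR's actual API.
from torchtune.modules.loss import LinearCrossEntropyLoss


def validate_custom_sharded_layers(loss_fn, custom_sharded_layers) -> None:
    """Fail fast at setup time instead of with a cryptic DTensor error mid-training."""
    if not isinstance(loss_fn, LinearCrossEntropyLoss):
        return
    if not custom_sharded_layers:
        # None or empty list: no selective sharding is requested, nothing to check.
        return
    if "output" not in custom_sharded_layers:
        raise ValueError(
            "When using LinearCrossEntropyLoss with custom_sharded_layers, "
            "'output' must be included to ensure tensor compatibility. "
            "Example: custom_sharded_layers = ['tok_embeddings', 'output']."
        )
```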
Testing
- Unit tests added in tests/recipes/test_validation.py
- Tests cover: missing output, correct config, None/empty layers, disabled parallelism, and non-LinearCrossEntropyLoss losses
- No changes to existing functionality - only adds validation
Example Error Message
When misconfigured, users will now see: ValueError: When using LinearCrossEntropyLoss with custom_sharded_layers, 'output' must be included to ensure tensor compatibility. Example: custom_sharded_layers = ['tok_embeddings', 'output'].
This guides users to the correct configuration immediately.
@nathan-az Hey Nathan, apologies for the confusion. I had meant to revise the current PR but created a new branch without thinking, so I wanted to give you a little context. After looking into it, here's what I found:
- The full fine-tuning config correctly includes custom_sharded_layers: ['tok_embeddings', 'output']
- But the LoRA fine-tuning configs typically only include ['tok_embeddings']
- This makes sense for LoRA's use case, but causes issues with LinearCrossEntropyLoss, which needs the output layer for its projection
Key changes from the previous attempt:
- No modifications to cross_entropy_loss.py
- Validation logic in a shared module to avoid duplication
- Error message that guides users to the fix
- Comprehensive unit tests
- Validation placed in _setup_model, where both variables are in scope
I guess my "fix" here is just making sure users get a descriptive error message rather than runtime error during training. Particularly for LoRA users who might not realize they need to include 'output' when using LinearCrossEntropyLoss.
Thanks again for your review. Your notes were super helpful, and I learned some cool stuff working through this :) If you think further revision is needed just let me know.
Hey @jscaldwell55 - thanks again for your work here! Two requests if you have the time. Could you confirm you're using the latest (or near-latest) nightlies? There were some changes made to the model root resharding logic during FSDP in response to some recent PyTorch changes, and this error is reminiscent of those.
In addition, could you provide an updated traceback/error logs showing what happens when you run your config on main, either here or in the original issue? The reason I ask is that a lot of the loss logic was updated after that bug was reported, where we stopped extracting the weight directly, so the old logs don't help much now. It would be good to see where the error actually reports now.
It's not immediately obvious to me why having tok_embeddings sharded separately (i.e. in custom_sharded_layers) would necessitate also including output, but having updated logs may make it clearer.
@nathan-az Got to work on this a bit this morning and completed the testing you requested. Here are my findings:
Test Environment
- PyTorch: 2.6.0+cu124 (latest available in Colab)
- Torchtune: 0.6.1
- Model Used: meta-llama/Llama-2-7b-hf with LoRA
Test Results
Single-GPU Testing
✅ Training runs successfully without DTensor errors in single-GPU mode with the problematic configuration:
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
custom_sharded_layers: ['tok_embeddings']  # without 'output'
However, single-GPU doesn't trigger the FSDP DTensor wrapping that causes the issue.
Multi-GPU Evidence
The original issue #2856 from June 29, 2025 shows the error still occurring with:
- PyTorch 2.8.0.dev20250625
- 8x A100 GPUs
- The same configuration, causing the exact DTensor/Tensor mismatch
Key Finding: The error explicitly occurs at cross_entropy_loss.py:71:
logits = F.linear(hidden_chunk, weight)  # DTensor × regular Tensor = error
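For reference, the same failure mode can be reproduced in a single process on CPU by calling F.linear with one DTensor and one plain Tensor argument. This is only a hedged sketch: the import paths assume PyTorch >= 2.5 (the public torch.distributed.tensor namespace), and the Replicate placement is chosen purely so the example runs with world_size=1.

```python
# Hypothetical single-process repro of the mixed DTensor/Tensor failure mode.
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
hidden_chunk = distribute_tensor(torch.randn(4, 16), mesh, [Replicate()])  # DTensor
weight = torch.randn(32, 16)  # plain torch.Tensor, like an unsharded output weight

# Expected to raise a "got mixed torch.Tensor and DTensor" RuntimeError,
# analogous to the error reported in #2856.
logits = F.linear(hidden_chunk, weight)
```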
Analysis
The issue manifests only in multi-GPU FSDP scenarios where:
- tok_embeddings in custom_sharded_layers → wrapped as DTensor
- output NOT in custom_sharded_layers → remains regular Tensor
- LinearCrossEntropyLoss uses output.weight for projection
- Mixed tensor types cause the runtime error
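On an actual multi-GPU run, one way to confirm this split is to inspect which parameters end up as DTensors after sharding. A small diagnostic sketch (model stands for the FSDP-wrapped model from _setup_model, and the DTensor import path again assumes a recent PyTorch):

```python
# Diagnostic sketch: list which parameters FSDP wrapped as DTensors and
# which remain plain tensors after sharding.
from torch.distributed.tensor import DTensor

for name, param in model.named_parameters():
    kind = "DTensor" if isinstance(param, DTensor) else "Tensor"
    print(f"{name}: {kind}")
```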
Note on Multi-GPU Testing
I wasn't able to generate updated error logs from multi-GPU testing due to resource constraints (only have access to single GPU via Colab). The single-GPU test doesn't trigger the FSDP DTensor wrapping, so it runs without errors.
The most recent multi-GPU error logs I could find are from issue #2856, which show the error occurring at the same location in the code (cross_entropy_loss.py:71). While these logs are from PyTorch 2.8.0.dev20250625, the error mechanism appears unchanged based on:
- The code structure in current main still has the same F.linear call
- The FSDP wrapping logic with custom_sharded_layers works the same way
- The loss function still extracts model.output.weight for projection
For next steps, would you like me to:
- Try to find a multi-GPU environment for definitive testing, or
- Proceed with the validation in this PR as a defensive measure, given the evidence?
Please let me know what you're thinking re best path forward. Really appreciate your patience and feedback as I work through this :)
Hey mate, thanks for checking in on this. I have access to hardware again so ran some tests.
I ran two configs - the LLaMA 3.1 8B lora config, and the LLaMA 3.2 1B lora config. Both with custom_sharded_layers: ['tok_embeddings']
I was not able to replicate the issue with 8B, but I was with 1B. I believe the key difference in architecture between these two is that the 1B model uses tied weights between the token embeddings and final output projection. I think this is likely why we see issues when one is sharded and the other is not.
If this is correct, I think adding a layer of validation would be useful, but that it should factor this in.
@nathan-az Awesome, thanks for running those tests!
This definitely makes sense to me - models with tied embeddings (where tok_embeddings.weight is the same tensor as output.weight) would have this issue when only one is in custom_sharded_layers. The 1B model uses weight tying for efficiency, while the 8B model has separate weights.
I'll update the validation to be more precise:
- Check if the model uses tied weights (i.e. when model.tok_embeddings.weight is model.output.weight)
- Only enforce the validation for tied-weight models using LinearCrossEntropyLoss
- Update the error message to explain the tied weights issue
This way we can avoid unnecessary restrictions on models that don't have this architectural constraint.
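As a rough sketch of that refinement (the attribute-based tied-weight check and the function signature are assumptions, not final code; with a TiedLinear-style implementation the weight may not live directly on model.output, so the real detection would need to account for that):

```python
# Illustrative refinement only: enforce the check just for tied-weight models.
# How tying is detected depends on the model implementation, so treat this as
# a placeholder rather than the final logic.
from torchtune.modules.loss import LinearCrossEntropyLoss


def _uses_tied_weights(model) -> bool:
    tok_emb_weight = getattr(getattr(model, "tok_embeddings", None), "weight", None)
    output_weight = getattr(getattr(model, "output", None), "weight", None)
    return tok_emb_weight is not None and tok_emb_weight is output_weight


def validate_custom_sharded_layers(model, loss_fn, custom_sharded_layers) -> None:
    if not isinstance(loss_fn, LinearCrossEntropyLoss):
        return
    if not custom_sharded_layers or not _uses_tied_weights(model):
        return
    if "output" not in custom_sharded_layers:
        raise ValueError(
            "This model ties tok_embeddings.weight and output.weight. When using "
            "LinearCrossEntropyLoss with custom_sharded_layers, include 'output' "
            "(e.g. ['tok_embeddings', 'output']) so both views of the tied weight "
            "are handled consistently by FSDP."
        )
```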
Quick question: Should I also check for the reverse case (where output is in custom_sharded_layers but tok_embeddings isn't)? Or would FSDP handle that differently?
No worries! So - one thing you reported that I haven't been able to replicate: for me, custom sharding of both layers also throws the type mismatch error when the weights are tied.
I'm not an expert in how FSDP (and FSDP2) work, but I see that the TiedLinear class's forward directly accesses the weight attribute from the underlying embedding via return self.linear(x, self.tied_module.weight). My best theory is that since the weight is being accessed directly, the FSDP hook to unshard is not running (i.e. by the time the matmul is done, I think the weights should just be a Tensor, not DTensor).
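To make the quoted pattern concrete, here is a simplified sketch of a tied-linear module (not torchtune's actual TiedLinear class):

```python
# Simplified sketch of the tied-linear pattern quoted above. The embedding's
# weight is read via a plain attribute access, so an FSDP pre-forward unshard
# hook registered on the embedding module is bypassed for this use of the weight.
import torch.nn as nn
import torch.nn.functional as F


class TiedLinearSketch(nn.Module):
    def __init__(self, tied_module: nn.Embedding):
        super().__init__()
        self.tied_module = tied_module

    def forward(self, x):
        # If tied_module is sharded by FSDP, this can hand F.linear the weight
        # in whatever state it happens to be in (e.g. a sharded DTensor),
        # producing the mixed Tensor/DTensor error downstream.
        return F.linear(x, self.tied_module.weight)
```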
@ebsmothers @joecummings sorry for the direct mention - if anything is immediately obvious to either of you (e.g. does my theory have merit, and is there an obvious/clean fix), that would be great. No pressure to look deeper, I just don't have that much experience with how FSDP does sharding.
IMO - if sharding tied weights doesn't currently work, it's not a huge deal and we can just add validation to confirm. By default, the transformer layers are still sharded when FSDP is used (this can be confirmed in the sharding code), and because these weights are tied, they consume half as much memory as they would by default.
@nathan-az I'll hold off on implementing anything until we hear from the other reviewers. Happy to go with whatever makes the most sense to y'all.
Really appreciate you digging into this with me; I learn so much getting into these distributed training edge cases.