Feature: linear cross entropy fusion
What does this PR do?
This PR introduces an implementation that fuses the lm_head linear layer with the cross-entropy loss, avoiding materialization of the intermediate logits tensor and reducing the memory footprint.
:warning: For major changes (either in lines of code or in impact), please make sure to first share and discuss a design doc with the team.
Contribution process
```mermaid
flowchart LR
  A[Pre-checks] --> B[PR Tests]
  subgraph Code Review/Approval
    C1[Expert Review] --> C2[Final Review]
  end
  B --> C1
  C2 --> D[Merge]
```
Pre-checks
- [ ] I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
- [x] I have added relevant unit tests
- [x] I have added relevant functional tests
- [x] I have added proper typing to my code (see Typing guidelines)
- [ ] I have added relevant documentation
- [x] I have run the autoformatter.sh on my PR
Code review
The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.
For MRs into `main` branch
(Step 1): Add PR label `Expert Review`
(Step 2): Collect the expert reviewers' reviews
- Attach the `Expert Review` label when your PR is ready for review.
- GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.
:warning: Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing. Final Review might get declined if these requirements are not fulfilled.
(Step 3): Final Review
- Add the `Final Review` label.
- GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.
(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.
For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval from either the core-adlr or core-nemo team.
Merging your PR
Any member of core-adlr and core-nemo will be able to merge your PR.
Details about this feature
Training LLMs typically involves a two-stage pipeline at the output layer: hidden states are projected into vocabulary logits via a linear transformation (the lm_head layer), followed by a cross-entropy loss computation against the target tokens. While conceptually simple, this workflow incurs substantial overhead. The intermediate logits tensor, whose size is proportional to batch size, sequence length, and vocabulary size, must be fully materialized in GPU memory, even though only one target token per position is ultimately used. This leads to a significant memory footprint and bandwidth consumption, limiting scalability and slowing training throughput. The following code snippet illustrates that workflow:
```python
hidden_state = ...  # shape = [batch, seqlen, dim]
weight = lm_head.weight  # shape = [vocabsize, dim]
labels = ...  # shape = [batch, seqlen]

# the full [batch, seqlen, vocabsize] logits tensor is materialized here
logits = hidden_state @ weight.T
loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
)
```
On top of the local logits tensor, other parallelism schemes may need additional intermediate buffers to collect the full vocabulary information across all GPUs. For example, the following snippet is a tensor-parallel-compatible computation composed of native torch ops:
```python
import torch.distributed as dist

tp_group = ...                                 # tensor-parallel process group
tp_world_size = dist.get_world_size(tp_group)

# hidden: [tokens, dim]; weight: local vocab shard of shape [vocabsize // tp_world_size, dim]
# labels: [tokens] target tokens; `reduction` as in torch cross-entropy
logits = hidden @ weight.T
whole_logits = torch.empty(
    (logits.shape[0], logits.shape[-1] * tp_world_size),
    dtype=logits.dtype,
    device=logits.device,
)
whole_logits_ref = [
    whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]]
    for i in range(tp_world_size)
]
dist.all_gather(whole_logits_ref, logits, group=tp_group)
loss = torch.nn.functional.cross_entropy(
    whole_logits.view(-1, whole_logits.shape[-1]), labels.view(-1), reduction=reduction
)
```
By fusing the linear projection and the cross-entropy into a single operation, this PR avoids materializing the intermediate logits tensor:
```python
hidden_state = ...  # shape = [batch, seqlen, dim]
weight = lm_head.weight  # shape = [vocabsize, dim]
labels = ...  # shape = [batch, seqlen]
loss = linear_cross_entropy(hidden_state, weight, labels, reduction="none")
```
This saves at least `2 * b * s * v` elements of memory (see the back-of-the-envelope estimate below):
- in the forward pass, there is no need to materialize the `logits` tensor, whose shape is `[batch, seqlen, vocabsize]`;
- in the backward pass, there is no need to materialize the gradient of the `logits` tensor, whose shape is also `[batch, seqlen, vocabsize]`.
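As a back-of-the-envelope illustration (the batch size, sequence length, vocabulary size, and bf16 element size below are assumptions chosen for the arithmetic, not the configurations benchmarked in this PR):

```python
# rough estimate of the >= 2*b*s*v saving; sizes are illustrative assumptions
b, s, v = 1, 8192, 131072        # micro-batch, sequence length, vocabulary size
bytes_per_elem = 2               # bf16 logits

logits_bytes = b * s * v * bytes_per_elem    # forward: logits tensor
dlogits_bytes = b * s * v * bytes_per_elem   # backward: grad of the logits tensor
print(f"saved >= {(logits_bytes + dlogits_bytes) / 2**30:.1f} GiB per micro-batch")  # ~4.0 GiB
```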
Functionalities
```python
def linear_cross_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
    sequence_parallel: bool = False,
) -> torch.Tensor: ...
```
- The input tensors are in `BF16` or `FP16` format; accumulation and other logic are carried out in `FP32` to avoid precision problems.
- It supports `Data-Parallel`, `Tensor-Parallel` (along vocabsize), and `Sequence-Parallel` (along seqlen), as shown in the usage sketch below:
  - when `tp_group is None`, it works in DP mode
  - when `tp_group is not None` and `sequence_parallel is False`, it works in TP mode
  - when `tp_group is not None` and `sequence_parallel is True`, it works in SP mode
- It supports specifying `ignore_index`, as native torch cross-entropy does.
- It supports specifying the `reduction` method, as native torch cross-entropy does.
- It is optimized for the latest NVIDIA Blackwell GPUs.
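A minimal usage sketch based on the signature above. The tensor shapes, the way `tp_group` and the weight shards are obtained, and the import path of `linear_cross_entropy` are illustrative assumptions rather than part of this PR's documented API:

```python
import torch

# `linear_cross_entropy` is the function added by this PR; its exact import path
# is not spelled out here, so treat this as a sketch rather than API documentation.

batch, seqlen, dim, vocab = 2, 4096, 4096, 131072    # illustrative sizes only
hidden = torch.randn(batch, seqlen, dim, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab, (batch, seqlen), device="cuda")

# DP mode: tp_group=None, each rank holds the full [vocab, dim] lm_head weight
weight = torch.randn(vocab, dim, dtype=torch.bfloat16, device="cuda")
loss = linear_cross_entropy(hidden, weight, labels, reduction="mean")

# TP mode: the vocab dimension is sharded across the ranks of `tp_group`,
# so each rank would pass its local [vocab // tp_size, dim] shard instead:
# loss = linear_cross_entropy(hidden, weight_shard, labels, tp_group=tp_group)

# SP mode: on top of TP, the sequence dimension is also sharded:
# loss = linear_cross_entropy(hidden_shard, weight_shard, labels,
#                             tp_group=tp_group, sequence_parallel=True)
```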
Performance and Storage
In DP mode, this PR can deliver both a performance boost and a storage reduction. You may try the following steps to reproduce it:
```bash
# start a Megatron image on GB200
$ pip install nvidia-cutlass-dsl==4.2.1
$ pip install PyGithub
$ pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py
$ torchrun --nproc_per_node=4 --nnodes=1 -m pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py
```
We have already run some experiments on real LLM models to verify the accuracy/convergence and the storage reduction when using this feature. @shjwudp Please comment here.
Did you try out storing the logits in bf16? It could save a lot of memory. Not sure if we need this fusion.
> We have already run some experiments on real LLM models to verify the accuracy/convergence and the storage reduction when using this feature. @shjwudp Please comment here.
Convergence Test
We conducted tests on Llama3-8B and DSv3-Proxy (with the MTP layer). The convergence performance of the Linear Cross-Entropy Fusion method was comparable to the M-Core baseline.
Memory Saving Test
Compared with TE Cross-Entropy Fusion, the proposed Linear Cross-Entropy Fusion (this MR) achieved a memory reduction of approximately 3.9 GB.
Cross-Entropy Fusion: (memory snapshot)
Linear-Cross-Entropy Fusion (this MR): (memory snapshot)
> Did you try out storing the logits in bf16? It could save a lot of memory. Not sure if we need this fusion.
@kvareddy Yes, we tested bf16 activations (the current TE cross-entropy fusion), and when the vocabulary size is large, Linear-Cross-Entropy Fusion remains more memory-efficient than the CE fusion. Moreover, linear-CE fusion better preserves the precision of the logits, as it avoids the down-cast and up-cast during computation.
> Did you try out storing the logits in bf16? It could save a lot of memory. Not sure if we need this fusion.
@kvareddy Jack has given a good case from the viewpoint of memory saving and convergence. Theoretically, the logits in bf16 could still become a bottleneck in fine-grained MoE models, which are under high memory pressure.
With TP, SP, and selective recomputation, the activation memory footprint of a transformer layer is roughly `34 * b * s * h / tp`, while the logits take `b * s * v / tp`, which is roughly equivalent to half of a layer. For DeepSeek-V3, the best PP layout as of now is `Et|(tt|)*30mL`; the MTP layer contains a full model with one transformer layer, so the last PP stage can become the memory-footprint bottleneck, considering the large vocab size and the optimizer states. Moreover, the memory footprint of a layer can be further reduced by fine-grained recomputation and offloading.
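To make the "roughly half of a layer" comparison concrete (the hidden size, vocabulary size, and TP degree below are approximate DeepSeek-V3-like values used as assumptions, not measured numbers):

```python
# per-token activation elements with TP+SP+selective recomputation vs. the logits
h = 7168      # hidden size (assumed, roughly DeepSeek-V3-like)
v = 129280    # vocabulary size (assumed, roughly DeepSeek-V3-like)
tp = 8        # tensor-parallel degree (assumed)

layer_elems_per_token = 34 * h / tp    # ~34*b*s*h/tp for one transformer layer
logit_elems_per_token = v / tp         # b*s*v/tp for the logits

print(logit_elems_per_token / layer_elems_per_token)   # ~0.53, i.e. roughly half a layer
```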
cc @yanring
Is this the dtype-aware cut cross-entropy kernel that helps with large vocabularies on already pretrained models, or just the less specialized Liger variant?
Linking this issue: https://github.com/NVIDIA/Megatron-LM/issues/1738
Can we make this feature optional? Also, can we move the kernels to TE and use them in the Megatron core?
@sanandaraj5597 Would you be willing to review this MR? Given your previous contributions to cross-entropy fusion, we would love to hear your feedback. Thank you!
> Can we make this feature optional? Also, can we move the kernels to TE and use them in the Megatron core?
Hi @kvareddy, this feature is currently optional and can be controlled via `--cross-entropy-fusion-impl linear`; it is not enabled by default. The linear cross-entropy implementation is built using Triton and the CuTe DSL. Given that some of Megatron's fused operators are also implemented in Triton, this approach isn't really an exception. If the linear cross-entropy proves to be sufficiently general, I think it could be moved into TE in the future.
Has this been tested with large vocab sizes? We observed a numerics issue with our port of a Liger kernel into our codebase a few months ago when the vocab size of the model exceeded 250k (we needed to upcast some accumulators, if I recall). It actually worked fine at Llama vocab sizes, so you may want to double-check: @shjwudp. It led to a subtle convergence issue IIRC.
> Has this been tested with large vocab sizes? We observed a numerics issue with our port of a Liger kernel into our codebase a few months ago when the vocab size of the model exceeded 250k (we needed to upcast some accumulators, if I recall). It actually worked fine at Llama vocab sizes, so you may want to double-check: @shjwudp. It led to a subtle convergence issue IIRC.
I indeed have not conducted experiments with a 256K vocabulary. As I understand it, the Linear-CE fusion algorithm is fully consistent with the standard cross-entropy formulation, so there shouldn't be any numerical issues even with a larger vocabulary size. You could run a test with the 256K vocab to confirm, for example along the lines of the sketch below; thanks for letting us know.
cc @Jianbing-D @kunlunl Please correct me if I’m mistaken.
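A minimal agreement check one might run (sketch only: the sizes are arbitrary, the tolerances are a guess, and `linear_cross_entropy` is assumed to be importable as in the signature above):

```python
import torch

def reference_loss(hidden, weight, labels):
    # unfused baseline: materialize fp32 logits, then standard cross-entropy
    logits = hidden.float() @ weight.float().T
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
    )

vocab, dim, tokens = 262144, 4096, 2048    # 256K vocabulary, modest token count
hidden = torch.randn(1, tokens, dim, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(vocab, dim, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab, (1, tokens), device="cuda")

fused = linear_cross_entropy(hidden, weight, labels, reduction="none")  # this PR
torch.testing.assert_close(fused.view(-1), reference_loss(hidden, weight, labels),
                           rtol=1e-4, atol=1e-4)
```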
> Is this the dtype-aware cut cross-entropy kernel that helps with large vocabularies on already pretrained models, or just the less specialized Liger variant?
Just wanted to double-check this, as it could massively reduce memory for finetuning (and improve speed in that case for LLMs) while being bit-wise numerically equivalent: https://github.com/apple/ml-cross-entropy
> Is this the dtype-aware cut cross-entropy kernel that helps with large vocabularies on already pretrained models, or just the less specialized Liger variant?
> Just wanted to double-check this, as it could massively reduce memory for finetuning (and improve speed in that case for LLMs) while being bit-wise numerically equivalent: https://github.com/apple/ml-cross-entropy
Hi, this kernel follows this scheme to do its calculation:
```python
logits = hidden_state.to(torch.float32) @ weight.T.to(torch.float32)
loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
)
```
The logits tensor is eliminated because the GEMM has been fused into the cross-entropy. As you might have noticed, the GEMM produces Float32, and the subsequent cross-entropy is also conducted in Float32, to ensure numerical precision.
We can analyze this operation in two separate parts, the forward pass and the backward pass; you can refer to the previous experiment results to check the latency speedup and the memory reduction. Implementation details:
- the forward pass completely eliminates the `logits` tensor, which leads to a significant memory reduction;
- the backward pass likewise does not need a full `dlogits` tensor, which also leads to a significant memory reduction;
- both the `logits` and `dlogits` tensors scale with `BatchSize`, `SequenceLength`, and `VocabSize`, so the larger those parameters are, the greater the memory savings.
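For intuition about why the full `logits` tensor never has to exist at once, here is a pure-PyTorch, forward-only sketch of the general chunked/online-logsumexp idea. This is only an illustration of the principle under an assumed chunk size; it is not the Triton/CuTe DSL kernel shipped in this PR, and `ignore_index`/`reduction` handling is omitted for brevity:

```python
import torch

def chunked_linear_cross_entropy_fwd(hidden, weight, labels, chunk=8192):
    """Process the vocabulary in slices and keep a running log-sum-exp, so only a
    [tokens, chunk] slice of the logits exists at any time (forward pass only)."""
    x = hidden.view(-1, hidden.shape[-1]).float()      # [tokens, dim]
    y = labels.view(-1)                                # [tokens]
    tokens = x.shape[0]
    running_lse = torch.full((tokens,), float("-inf"), device=x.device)
    target_logit = torch.empty(tokens, device=x.device)
    for start in range(0, weight.shape[0], chunk):
        w = weight[start : start + chunk].float()      # [chunk, dim] vocab slice
        logits = x @ w.T                               # only [tokens, chunk] lives here
        running_lse = torch.logaddexp(running_lse, torch.logsumexp(logits, dim=-1))
        in_chunk = (y >= start) & (y < start + w.shape[0])
        target_logit[in_chunk] = logits[in_chunk, y[in_chunk] - start]
    # per-token loss: logsumexp(logits) - logit[target], same as cross_entropy(reduction="none")
    return running_lse - target_logit
```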