[Dev] Feature: linear cross entropy fusion
What does this PR do?
This PR introduces an implementation that fuses the lm_head linear layer with the cross-entropy loss, avoiding materialization of the intermediate logits tensor and thereby reducing the memory footprint.
PR to the main branch: https://github.com/NVIDIA/Megatron-LM/pull/2206
:warning: For major changes (either in lines of code or in impact), please make sure to first share and discuss a design doc with the team.
Details about this feature
Training an LLM typically involves a two-stage pipeline at the output layer: hidden states are projected into vocabulary logits via a linear transformation (the lm_head layer), followed by cross-entropy loss computation against the target tokens. While conceptually simple, this workflow incurs substantial overhead. The intermediate logits tensor, whose size is proportional to batch size, sequence length, and vocabulary size, must be fully materialized in GPU memory, even though only one target token per position is ultimately used. This leads to a significant memory footprint and bandwidth consumption, limiting scalability and slowing training throughput. The following code snippet illustrates that workflow:
```python
hidden_state = xxx        # shape = [batch, seqlen, dim]
weight = lm_head.weight   # shape = [vocabsize, dim]
labels = xxx              # shape = [batch, seqlen]

logits = hidden_state @ weight.T   # shape = [batch, seqlen, vocabsize]
loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
)
```
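To put a number on that footprint, here is a quick back-of-the-envelope estimate; the batch, sequence length, and vocabulary sizes below are illustrative and not taken from this PR:

```python
# Illustrative sizes only; real training configs will differ.
batch, seqlen, vocabsize = 4, 8192, 131072
bytes_per_elem = 2  # BF16 / FP16

logits_bytes = batch * seqlen * vocabsize * bytes_per_elem
print(f"logits tensor: {logits_bytes / 2**30:.1f} GiB")  # -> 8.0 GiB
# The backward pass needs a gradient of the same shape, so the un-fused
# path holds roughly twice this amount at peak.
```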
On top of the local logits tensor, other techniques may need additional intermediate buffers to collect the full information across all GPUs. For example, the following snippet is a TP-compatible layer built from native torch ops:
```python
tp_rank = xxx
tp_world_size = xxx

logits = hidden @ weight.T   # local shard, vocab dimension split across TP ranks

# gather the vocab-sharded logits from all TP ranks into one buffer
whole_logits = torch.empty(
    (logits.shape[0], logits.shape[-1] * tp_world_size),
    dtype=logits.dtype,
    device=logits.device,
)
whole_logits_ref = [
    whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]]
    for i in range(tp_world_size)
]
dist.all_gather(whole_logits_ref, logits, group=tp_group)

logprobs = torch.nn.functional.cross_entropy(
    whole_logits.view(-1, whole_logits.shape[-1]), labels.view(-1), reduction=reduction
)
```
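On top of the local shard, the gathered buffer in the snippet above is `tp_world_size` times larger on every rank. A rough, illustrative calculation (numbers are not from this PR):

```python
# Illustrative only: per-rank cost of whole_logits in the snippet above.
batch, seqlen, vocabsize, tp_world_size = 4, 8192, 131072, 8
bytes_per_elem = 2  # BF16 / FP16

local_shard_bytes = batch * seqlen * (vocabsize // tp_world_size) * bytes_per_elem
whole_logits_bytes = local_shard_bytes * tp_world_size
print(f"local shard: {local_shard_bytes / 2**30:.1f} GiB")       # -> 1.0 GiB
print(f"gathered buffer: {whole_logits_bytes / 2**30:.1f} GiB")  # -> 8.0 GiB per rank
```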
By fusing the linear projection and the cross-entropy loss into a single operation, this PR avoids materializing the intermediate logits tensor:
```python
hidden_state = xxx        # shape = [batch, seqlen, dim]
weight = lm_head.weight   # shape = [vocabsize, dim]
labels = xxx              # shape = [batch, seqlen]

loss = linear_cross_entropy(hidden_state, weight, labels, reduction="none")
```
This reduces the memory footprint by at least `2 * batch * seqlen * vocabsize` elements:
- in the forward pass, there is no need to materialize the `logits` tensor, whose shape is `[batch, seqlen, vocabsize]`
- in the backward pass, there is no need to materialize the gradient of the `logits` tensor, whose shape is also `[batch, seqlen, vocabsize]`
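For intuition, the same per-token loss can be computed without ever holding the full `[batch, seqlen, vocabsize]` tensor, e.g. by streaming over the vocabulary in chunks and keeping only running logsumexp statistics plus the target logit. The snippet below is a plain-PyTorch sketch of that idea (the function name and chunking are made up for illustration); the actual fused kernel in this PR performs the projection and the reduction in a single operation rather than looping in Python:

```python
import torch

def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=8192):
    """Reference-only sketch: per-token loss without materializing the full logits.

    hidden: [N, dim] (tokens already flattened), weight: [vocabsize, dim], labels: [N].
    """
    n_tokens = hidden.shape[0]
    running_max = torch.full((n_tokens,), float("-inf"), device=hidden.device)
    running_sumexp = torch.zeros(n_tokens, device=hidden.device)
    target_logit = torch.zeros(n_tokens, device=hidden.device)

    for start in range(0, weight.shape[0], chunk_size):
        w = weight[start : start + chunk_size]       # [chunk, dim] slice of the vocab
        logits = (hidden @ w.T).float()              # only an [N, chunk] slice is ever live
        # online logsumexp update (streaming-softmax trick)
        chunk_max = logits.max(dim=-1).values
        new_max = torch.maximum(running_max, chunk_max)
        running_sumexp = running_sumexp * torch.exp(running_max - new_max) \
            + torch.exp(logits - new_max[:, None]).sum(dim=-1)
        running_max = new_max
        # pick out the logit of the target token if it falls in this chunk
        in_chunk = (labels >= start) & (labels < start + w.shape[0])
        idx = (labels - start).clamp(0, w.shape[0] - 1)
        target_logit = torch.where(
            in_chunk, logits.gather(-1, idx[:, None]).squeeze(-1), target_logit
        )

    # cross entropy per token: logsumexp(logits) - logit[target]
    return running_max + torch.log(running_sumexp) - target_logit
```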
Functionalities
```python
def linear_cross_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
    sequence_parallel: bool = False,
) -> torch.Tensor
```
- The input tensors are in `BF16` or `FP16` format; accumulation and other numerically sensitive logic is carried out in `FP32` to avoid precision problems.
- It supports `Data-Parallel`, `Tensor-Parallel` (along `vocabsize`), and `Sequence-Parallel` (along `seqlen`); an illustrative TP-mode call is sketched after this list:
  - when `tp_group is None`, it works in DP mode
  - when `tp_group is not None` and `sequence_parallel is False`, it works in TP mode
  - when `tp_group is not None` and `sequence_parallel is True`, it works in SP mode
- It supports specifying `ignore_index`, as native torch cross-entropy does.
- It supports specifying the `reduction` method, as native torch cross-entropy does.
- It is optimized for the latest NVIDIA Blackwell GPUs.
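For illustration, a TP-mode call might look like the sketch below. The import path, the process-group setup, and the assumption that `hidden` and `labels` are replicated across the TP group while `weight` is sharded along the vocabulary dimension are inferred from the description above, not copied from this PR:

```python
import torch
import torch.distributed as dist

# Import path assumed from the test location; adjust to the actual module in this PR.
from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
tp_group = dist.new_group(list(range(dist.get_world_size())))  # all ranks form one TP group here

batch, seqlen, dim, vocabsize = 2, 4096, 4096, 131072
tp_size = dist.get_world_size()

# Each rank holds its vocab shard of lm_head's weight: [vocabsize // tp_size, dim]
hidden = torch.randn(batch, seqlen, dim, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(vocabsize // tp_size, dim, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocabsize, (batch, seqlen), device="cuda")

# TP mode: tp_group is given, sequence_parallel stays False
loss = linear_cross_entropy(hidden, weight, labels, tp_group=tp_group, reduction="mean")
```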
Performance and Storage
In DP mode, this PR can deliver both a performance boost and a storage reduction.
You may try the following steps to reproduce it:
```bash
# start a Megatron image on GB200
$ pip install nvidia-cutlass-dsl==4.2.1
$ pip install PyGithub
$ pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py
$ torchrun --nproc_per_node=4 --nnodes=1 -m pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py
```
Contribution process
```mermaid
flowchart LR
  A[Pre-checks] --> B[PR Tests]
  subgraph Code Review/Approval
    C1[Expert Review] --> C2[Final Review]
  end
  B --> C1
  C2 --> D[Merge]
```
Pre-checks
- [ ] I want this PR in a versioned release and have added the appropriate Milestone (e.g., `Core 0.8`)
- [x] I have added relevant unit tests
- [x] I have added relevant functional tests
- [x] I have added proper typing to my code (see the Typing guidelines)
- [ ] I have added relevant documentation
- [x] I have run the autoformatter.sh on my PR
Code review
The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.
For MRs into `main` branch
(Step 1): Add the `Expert Review` PR label
(Step 2): Collect the expert reviewers' reviews
- Attach the `Expert Review` label when your PR is ready for review.
- GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

:warning: Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing. Final Review might get declined if these requirements are not fulfilled.
(Step 3): Final Review
- Add the `Final Review` label.
- GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.
(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into `core_r*` release branches, select `Cherry-pick` after this PR has been merged to open a new PR into the release branch.
For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either [email protected] or [email protected].
Merging your PR
Any member of `core-adlr` or `core-nemo` will be able to merge your PR.
For the convergence test, please refer to: https://github.com/NVIDIA/Megatron-LM/pull/2206#issuecomment-3516697601
Linking: https://github.com/NVIDIA/Megatron-LM/pull/2206
@lhb8125 Hey Hongbin, could you help run the functional test on GitLab for this PR?
/ok to test 78c827e
/ok to test 011947d
/ok to test 1b603b9
/ok to test fb2ee78
/ok to test 4dc347c
/ok to test ae4d83a
Thank you for your contribution!
NVIDIA Megatron-LM is currently transitioning to development on Github. We will aim to review your PR after we complete our transition and stabilize our Github development process.
Thank you for your understanding.