
Feature: linear cross entropy fusion

Open Jianbing-D opened this issue 1 month ago • 16 comments

What does this PR do?

This PR introduces an implementation that fuses the lm_head linear layer with the cross-entropy loss, avoiding materialization of the intermediate logits tensor and thereby reducing the memory footprint.

:warning: For major changes (either in lines of code or in their impact), please make sure to first share and discuss a design doc with the team.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • [ ] I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • [x] I have added relevant unit tests
  • [x] I have added relevant functional tests
  • [x] I have added proper typing to my code (see the Typing guidelines)
  • [ ] I have added relevant documentation
  • [x] I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

:warning: Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

Jianbing-D, Nov 11 '25 08:11

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

copy-pr-bot[bot], Nov 11 '25 08:11

Details about this feature

Training an LLM typically involves a two-stage pipeline at the output layer: hidden states are projected into vocabulary logits by a linear transformation (the lm_head layer), and the cross-entropy loss is then computed against the target tokens. While conceptually simple, this workflow incurs substantial overhead. The intermediate logits tensor, whose size is proportional to batch size × sequence length × vocabulary size, must be fully materialized in GPU memory, even though each position's loss only needs the target token's logit and a log-sum-exp over the vocabulary. This costs significant memory and bandwidth, limiting scalability and slowing training throughput. The following snippet illustrates the workflow:

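import torch
import torch.nn.functional as F

batch, seqlen, dim, vocabsize = 2, 8, 64, 1000  # small illustrative sizes
hidden_state = torch.randn(batch, seqlen, dim)  # shape = [batch, seqlen, dim]
lm_head = torch.nn.Linear(dim, vocabsize, bias=False)
weight = lm_head.weight  # shape = [vocabsize, dim]
labels = torch.randint(0, vocabsize, (batch, seqlen))  # shape = [batch, seqlen]

logits = hidden_state @ weight.T  # materialized: [batch, seqlen, vocabsize]
# cross_entropy expects classes in dim 1, so flatten tokens to [batch * seqlen, vocabsize]
loss = F.cross_entropy(logits.view(-1, vocabsize), labels.view(-1), reduction="none")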

On top of the local logits tensor, parallel training may need additional intermediate buffers to collect the full logits across GPUs. For example, the following snippet is a tensor-parallel (TP) compatible reference built from native PyTorch ops:

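import torch
import torch.distributed as dist
import torch.nn.functional as F

def tp_cross_entropy_reference(hidden, weight, labels, tp_group, reduction="none"):
    # Assumes an initialized process group; each TP rank holds a contiguous vocab
    # shard of lm_head, so weight is [vocabsize // tp_world_size, dim].
    # hidden: [num_tokens, dim], labels: [num_tokens] with global vocab ids.
    tp_world_size = dist.get_world_size(group=tp_group)

    logits = hidden @ weight.T  # local shard: [num_tokens, vocabsize // tp_world_size]

    # Extra buffers: one received shard per rank, then the full-vocab logits.
    # Note: dist.all_gather is not autograd-aware, so this is a forward-only reference.
    gathered = [torch.empty_like(logits) for _ in range(tp_world_size)]
    dist.all_gather(gathered, logits, group=tp_group)
    whole_logits = torch.cat(gathered, dim=-1)  # [num_tokens, vocabsize]

    return F.cross_entropy(whole_logits, labels, reduction=reduction)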
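The gather above materializes the full-vocab logits on every rank. A common way to avoid that, and the one Megatron's existing vocab-parallel cross-entropy follows, is to combine three per-token statistics across ranks instead: the global max (an all-reduce with MAX), the sum of exp-logits, and the target logit (all-reduces with SUM). The sketch below simulates the TP ranks in a single process by splitting `weight` along the vocabulary; `simulated_vocab_parallel_ce` is a hypothetical name, each loop over shards stands in for per-rank work followed by a `dist.all_reduce`, and it is not necessarily how this PR's TP mode is implemented.

```python
import torch


def simulated_vocab_parallel_ce(hidden, weight, labels, tp_world_size):
    """Per-token CE (reduction="none") without concatenating full-vocab logits.

    hidden: [N, dim], weight: [V, dim], labels: [N] global vocab ids.
    No ignore_index handling, to keep the sketch short.
    """
    shards = weight.chunk(tp_world_size, dim=0)  # "rank" r owns a contiguous vocab slice
    local_logits = [hidden @ w.T for w in shards]  # each rank's [N, V_r] logits

    # all-reduce(MAX): global per-token max for a numerically stable exp
    global_max = torch.stack([lg.max(dim=-1).values for lg in local_logits]).amax(dim=0)

    # all-reduce(SUM): sum of exp(logit - max) across ranks
    sum_exp = sum(torch.exp(lg - global_max[:, None]).sum(dim=-1) for lg in local_logits)

    # all-reduce(SUM): the target logit, contributed only by the rank owning the label
    target_logit = torch.zeros_like(global_max)
    offset = 0
    for lg in local_logits:
        v_r = lg.shape[-1]
        owned = (labels >= offset) & (labels < offset + v_r)
        idx = (labels - offset).clamp(0, v_r - 1)
        target_logit = target_logit + torch.where(owned, lg.gather(1, idx[:, None]).squeeze(1), 0.0)
        offset += v_r

    return torch.log(sum_exp) + global_max - target_logit
```

Each rank still holds its own [N, V/tp] shard here; only the gathered full-vocab buffer disappears, which is why fusing the GEMM as well is what removes the logits entirely.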

By fusing the linear layer and cross-entropy into a single operation, this PR avoids materializing the intermediate logits tensor:

hidden_state = ...  # shape = [batch, seqlen, dim]
weight = lm_head.weight  # shape = [vocabsize, dim]
labels = ...  # shape = [batch, seqlen]

# one fused op: no [batch, seqlen, vocabsize] logits tensor is created
loss = linear_cross_entropy(hidden_state, weight, labels, reduction="none")

This reduces the memory footprint by at least 2bsv elements (b = batch size, s = sequence length, v = vocabulary size):

  • in the forward pass, the logits tensor of shape [batch, seqlen, vocabsize] is never materialized;
  • in the backward pass, the gradient of the logits, also of shape [batch, seqlen, vocabsize], is never materialized.
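As a back-of-the-envelope check of the 2bsv figure, the hypothetical helper below counts the bytes of one logits tensor plus its gradient. The sizes in the usage line are assumptions for illustration: 8192 tokens per micro-batch, a 128,256-entry vocabulary (the Llama 3 tokenizer size), and 2-byte BF16 elements.

```python
def logits_and_grad_bytes(batch: int, seqlen: int, vocabsize: int, bytes_per_elem: int = 2) -> int:
    """Bytes of a [batch, seqlen, vocabsize] logits tensor plus its gradient (2*b*s*v elements)."""
    return 2 * batch * seqlen * vocabsize * bytes_per_elem


saved = logits_and_grad_bytes(batch=1, seqlen=8192, vocabsize=128256)  # BF16
print(f"{saved / 2**30:.2f} GiB")  # 3.91 GiB
```

With FP32 logits (`bytes_per_elem=4`) the figure doubles, and it grows linearly in each of b, s, and v.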

Functionality

import typing
import torch

def linear_cross_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
    sequence_parallel: bool = False,
) -> torch.Tensor: ...
  • The inputs are in BF16 or FP16 format; accumulation and the rest of the computation are performed in FP32 to avoid precision problems.
  • It supports data parallelism (DP), tensor parallelism (TP) along vocabsize, and sequence parallelism (SP) along seqlen:
    1. when tp_group is None, it works in DP mode;
    2. when tp_group is not None and sequence_parallel is False, it works in TP mode;
    3. when tp_group is not None and sequence_parallel is True, it works in SP mode.
  • It supports specifying ignore_index, as PyTorch's native cross-entropy does.
  • It supports specifying the reduction method, as PyTorch's native cross-entropy does.
  • It is optimized for the latest NVIDIA Blackwell GPUs.
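To illustrate the idea behind the fusion (this is not the PR's Triton/CuTe DSL kernel), here is a forward-only PyTorch sketch under stated assumptions: it walks the vocabulary in chunks, keeps an online log-sum-exp in FP32, and mimics the `ignore_index` and `reduction` semantics of `torch.nn.functional.cross_entropy`. The function name `chunked_linear_cross_entropy` and the `chunk_size` parameter are hypothetical.

```python
import torch


@torch.no_grad()  # forward-only sketch; a real fused op also needs a custom backward
def chunked_linear_cross_entropy(hidden, weight, labels, reduction="mean",
                                 ignore_index=-100, chunk_size=4096):
    """CE of (hidden @ weight.T) vs labels, holding at most [N, chunk_size] logits."""
    hidden = hidden.reshape(-1, hidden.shape[-1]).float()  # [N, dim], FP32 compute
    labels = labels.reshape(-1)  # [N]
    n, vocab = hidden.shape[0], weight.shape[0]
    running_max = torch.full((n,), float("-inf"), device=hidden.device)
    running_sum = torch.zeros(n, device=hidden.device)
    target_logit = torch.zeros(n, device=hidden.device)

    for start in range(0, vocab, chunk_size):
        w = weight[start:start + chunk_size].float()  # [c, dim]
        logits = hidden @ w.T  # [N, c]: the only logits block alive at a time
        # Online log-sum-exp: rescale the running sum whenever the max grows.
        new_max = torch.maximum(running_max, logits.max(dim=-1).values)
        running_sum = (running_sum * torch.exp(running_max - new_max)
                       + torch.exp(logits - new_max[:, None]).sum(dim=-1))
        running_max = new_max
        # Pick up the target logit if the label falls inside this chunk.
        in_chunk = (labels >= start) & (labels < start + w.shape[0])
        idx = (labels - start).clamp(0, w.shape[0] - 1)
        target_logit += torch.where(in_chunk, logits.gather(1, idx[:, None]).squeeze(1), 0.0)

    loss = running_max + torch.log(running_sum) - target_logit  # per-token CE
    valid = labels != ignore_index
    loss = torch.where(valid, loss, 0.0)
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    return loss.sum() / valid.sum()  # "mean" over non-ignored tokens, like F.cross_entropy
```

The Python loop and the per-chunk GEMM only show the data flow; a production kernel would typically also tile over tokens and keep the running statistics on-chip.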

Performance and Storage

In DP mode, this PR delivers a performance boost and memory reduction in the following configuration: [image]

You can reproduce it with the following steps:

# start a Megatron image on GB200
$ pip install nvidia-cutlass-dsl==4.2.1
$ pip install PyGithub
$ pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py
$ torchrun --nproc_per_node=4 --nnodes=1 -m pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py

Jianbing-D, Nov 11 '25 09:11

We have already run verification experiments on real LLMs to check accuracy/convergence and memory savings with this feature enabled. @shjwudp, please comment here.

Jianbing-D, Nov 11 '25 09:11

Did you try storing the logits in bf16? It could save a lot of memory. Not sure if we need this fusion.

kvareddy, Nov 11 '25 09:11

> We have already run verification experiments on real LLMs to check accuracy/convergence and memory savings with this feature enabled. @shjwudp, please comment here.

Convergence Test

We conducted tests on Llama3-8B and DSv3-Proxy (with the MTP layer). The convergence performance of the Linear Cross-Entropy Fusion method was comparable to the M-Core baseline.

[Screenshot: convergence comparison]

Memory Saving Test

Compared with TE Cross-Entropy Fusion, the proposed Linear Cross-Entropy Fusion (this MR) achieved a memory reduction of approximately 3.9 GB.

Cross-Entropy Fusion: [screenshot]

Linear-Cross-Entropy Fusion (This MR): [screenshot]

shjwudp, Nov 11 '25 12:11

> Did you try storing the logits in bf16? It could save a lot of memory. Not sure if we need this fusion.

@kvareddy Yes, we tested bf16 activations (the current TE cross-entropy fusion), and when the vocabulary size is large, the Linear-Cross-Entropy Fusion remains more memory-efficient than the CE fusion. Moreover, the linear-CE fusion better preserves the precision of the logits, as it avoids down-casting and up-casting during computation.

shjwudp, Nov 11 '25 12:11

> Did you try storing the logits in bf16? It could save a lot of memory. Not sure if we need this fusion.

@kvareddy Jack has made a good case from the perspective of memory savings and convergence. In theory, bf16 logits could still become a bottleneck in fine-grained MoE models, which are under high memory pressure.

With TP, SP, and selective recomputation, the activation memory of a layer is 34·bsh/tp, while the logits take bsv/tp, which is roughly half of a layer at DeepSeek-V3's hidden and vocabulary sizes. For DeepSeek-V3, the best PP layout so far is Et|(tt|)*30mL. The MTP layer contains a full model with one transformer layer, so the last PP stage can become the memory bottleneck, given the large vocabulary and the optimizer states. Moreover, the per-layer footprint can be reduced further by fine-grained recomputation and offloading.
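To sanity-check the "roughly half of a layer" estimate, the two expressions can be compared directly as written: (bsv/tp) / (34·bsh/tp) = v / (34h), so b, s, and tp cancel. The sizes below are assumptions taken from the public DeepSeek-V3 configuration (hidden size 7168, vocabulary 129,280), and the helper name is hypothetical.

```python
def logits_to_layer_ratio(hidden_size: int, vocab_size: int) -> float:
    """(b*s*v/tp) / (34*b*s*h/tp); batch, sequence length and TP size cancel out."""
    return vocab_size / (34 * hidden_size)


print(round(logits_to_layer_ratio(hidden_size=7168, vocab_size=129280), 2))  # 0.53
```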

cc @yanring

lhb8125, Nov 11 '25 13:11

Is this the dtype-aware Cut Cross-Entropy kernel that helps with large vocabularies on already pretrained models, or just the less specialized Liger variant?

Skylion007, Nov 11 '25 19:11

Linking this issue: https://github.com/NVIDIA/Megatron-LM/issues/1738

Skylion007, Nov 11 '25 19:11

Can we make this feature optional? Also, can we move the kernels to TE and use them in Megatron Core?

kvareddy, Nov 12 '25 10:11

@sanandaraj5597 Would you be willing to review this MR? Given your previous contributions to cross-entropy fusion, we would love to hear your feedback. Thank you!

shjwudp, Nov 13 '25 07:11

> Can we make this feature optional? Also, can we move the kernels to TE and use them in Megatron Core?

Hi @kvareddy, this feature is already optional: it is controlled via --cross-entropy-fusion-impl linear and is not enabled by default. The linear cross-entropy implementation is built with Triton and the CuTe DSL. Since some of Megatron's fused operators are also implemented in Triton, this approach isn't really an exception. If the linear cross-entropy proves sufficiently general, I think it could be moved into TE in the future.

shjwudp, Nov 13 '25 07:11

Has this been tested with large vocab sizes? A few months ago we observed numerics issues with our port of a Liger kernel into our codebase once the model's vocab size exceeded 250k (we needed to upcast some accumulators, if I recall correctly). It led to a subtle convergence issue, IIRC, while it worked fine at Llama vocab sizes, so you may want to double-check: @shjwudp.

Skylion007, Nov 15 '25 19:11

> Has this been tested with large vocab sizes? A few months ago we observed numerics issues with our port of a Liger kernel into our codebase once the model's vocab size exceeded 250k (we needed to upcast some accumulators, if I recall correctly). It led to a subtle convergence issue, IIRC, while it worked fine at Llama vocab sizes, so you may want to double-check: @shjwudp.

I have not yet conducted experiments with a 256K vocabulary. As I understand it, the Linear-CE fusion algorithm is fully consistent with the standard cross-entropy formulation, so there shouldn't be any numerical issues even with a larger vocabulary. You could run a test with the 256K vocab to confirm. Thanks for letting us know.

cc @Jianbing-D @kunlunl Please correct me if I’m mistaken.

shjwudp, Nov 17 '25 08:11

> Is this the dtype-aware Cut Cross-Entropy kernel that helps with large vocabularies on already pretrained models, or just the less specialized Liger variant?

Just wanted to double-check this, since it could massively reduce memory for fine-tuning (and improve speed in that case for LLMs) while being bitwise numerically equivalent: https://github.com/apple/ml-cross-entropy

Skylion007, Nov 22 '25 19:11

> Is this the dtype-aware Cut Cross-Entropy kernel that helps with large vocabularies on already pretrained models, or just the less specialized Liger variant?
>
> Just wanted to double-check this, since it could massively reduce memory for fine-tuning (and improve speed in that case for LLMs) while being bitwise numerically equivalent: https://github.com/apple/ml-cross-entropy

Hi, this kernel follows this computation scheme:

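# Reference semantics: FP32 GEMM followed by FP32 cross-entropy (the fused kernel never stores `logits`)
logits = hidden_state.to(torch.float32) @ weight.T.to(torch.float32)  # [batch, seqlen, vocabsize]
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none")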

The logits tensor is never materialized because the GEMM is fused into the cross-entropy computation. Note that the GEMM produces FP32 outputs and the subsequent cross-entropy is also computed in FP32, to preserve numerical precision.

This operation can be analyzed in two parts, the forward pass and the backward pass; you can refer to the earlier experiment results for the latency speedup and memory reduction. Implementation details:

  • the forward pass completely eliminates the logits tensor, which leads to a significant memory reduction;
  • the backward pass does not need a full dlogits tensor either, which also leads to a significant memory reduction;
  • both the logits and dlogits tensors scale with batch size, sequence length, and vocabulary size, so the larger these are, the greater the memory savings.
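To make the backward-pass point concrete, here is a hypothetical plain-PyTorch sketch (again not the PR's kernel). It assumes the forward pass saved the per-token log-sum-exp `lse`; each vocabulary chunk's logits are then recomputed, turned into a [num_tokens, chunk_size] slice of dlogits = softmax - one_hot (scaled by the upstream gradient), and immediately folded into the gradients of hidden and weight. The function name and its arguments are illustrative.

```python
import torch


@torch.no_grad()
def chunked_linear_ce_backward(hidden, weight, labels, lse, grad_loss,
                               ignore_index=-100, chunk_size=4096):
    """Gradients of the per-token loss (reduction="none") w.r.t. hidden and weight.

    hidden: [N, dim], weight: [V, dim], labels: [N], lse: [N] log-sum-exp of the
    logits saved by the forward pass, grad_loss: [N] upstream gradient.
    """
    hidden32, weight32 = hidden.float(), weight.float()
    # Ignored tokens contribute no loss, hence no gradient.
    grad_loss = torch.where(labels == ignore_index, 0.0, grad_loss.float())
    grad_hidden = torch.zeros_like(hidden32)
    grad_weight = torch.zeros_like(weight32)
    for start in range(0, weight32.shape[0], chunk_size):
        w = weight32[start:start + chunk_size]  # [c, dim]
        logits = hidden32 @ w.T  # recomputed [N, c] block
        dlogits = torch.exp(logits - lse[:, None])  # softmax probabilities for this chunk
        rows = ((labels >= start) & (labels < start + w.shape[0])).nonzero(as_tuple=True)[0]
        dlogits[rows, labels[rows] - start] -= 1.0  # subtract the one-hot target
        dlogits *= grad_loss[:, None]  # chain rule with the upstream gradient
        grad_hidden += dlogits @ w  # [N, dim]
        grad_weight[start:start + w.shape[0]] = dlogits.T @ hidden32  # [c, dim]
    return grad_hidden, grad_weight
```

The trade-off is recomputing each logits block instead of storing it: peak extra memory is one [N, chunk_size] dlogits block rather than a full [N, vocabsize] tensor.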

Jianbing-D, Nov 24 '25 01:11