Memory-efficient cross entropy with a fused linear layer

Jianbing-D opened this issue 10 months ago

Implemented the forward and backward passes of the following compute logic. The fused implementation eliminates many intermediate storage tensors, which reduces peak memory usage.

Equivalent compute logic:

import typing

import torch

def run_torch_entropy(hidden: torch.Tensor,
                      weight: torch.Tensor,
                      labels: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    # Materializes the full logits tensor; this is what dominates peak memory.
    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
    entropy = entropy_a - entropy_b  # [num_tokens]
    logprobs = torch.nn.functional.cross_entropy(logits, labels)  # scalar (mean over tokens)
    return logprobs, entropy
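
For intuition, the two entropy terms implement a standard identity: since log p_i = logits_i - logsumexp(logits), the entropy -sum_i p_i * log p_i reduces to logsumexp(logits) - sum_i p_i * logits_i, which is exactly entropy_a - entropy_b above. The fused kernel's memory win comes from never materializing the full [num_tokens, vocab_size] logits tensor. Below is a minimal pure-torch sketch of the same chunking idea, for illustration only: the function name and chunk_size are made up, and under autograd this naive version still saves each chunk's logits for backward, whereas a fused kernel recomputes them.

def chunked_entropy_loss(hidden: torch.Tensor,
                         weight: torch.Tensor,
                         labels: torch.Tensor,
                         chunk_size: int = 1024):
    # Only [chunk_size, vocab_size] logits are alive at a time,
    # instead of the full [num_tokens, vocab_size] tensor.
    losses, entropies = [], []
    w = weight.to(torch.float32)
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size].to(torch.float32)
        logits = h @ w  # [chunk, vocab_size]
        lse = torch.logsumexp(logits, dim=-1)  # [chunk]
        pd = torch.nn.functional.softmax(logits, dim=-1)
        entropies.append(lse - torch.sum(pd * logits, dim=-1))
        losses.append(torch.nn.functional.cross_entropy(
            logits, labels[start:start + chunk_size], reduction="none"))
    return torch.cat(losses).mean(), torch.cat(entropies)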

API

import torch
from verl.utils.kernel import linear_cross_entropy

num_tokens, hidden_size, vocab_size = 4096, 4096, 152064  # example sizes

hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
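
A quick numerical sanity check against the pure-torch reference above (a sketch; the tolerances are illustrative, and both outputs are upcast to fp32 for comparison):

loss_ref, entropy_ref = run_torch_entropy(hidden, weight, labels)
torch.testing.assert_close(loss.float(), loss_ref.float(), rtol=1e-3, atol=1e-3)
torch.testing.assert_close(entropy.float(), entropy_ref.float(), rtol=1e-3, atol=1e-3)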

Storage and latency

[image: storage and latency comparison]

Unit test

$ cd verl/
$ python3 tests/kernel/test_memory_efficient_entropy.py

NOTE

For compatibility, torch.library.triton_op was not applied to these APIs, so torch.compile may not be usable on top of them.
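
For reference, a minimal sketch of the torch.library.triton_op pattern the NOTE refers to, using a toy elementwise kernel rather than the entropy kernel (assumes PyTorch >= 2.6, where triton_op and wrap_triton are public; all names here are illustrative):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

@triton_op("mylib::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Registering through triton_op makes the op composable with
    # torch.compile, which is what the NOTE says was deliberately skipped.
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    wrap_triton(_add_kernel)[grid](x, y, out, n, BLOCK=1024)
    return out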

Jianbing-D · Mar 04 '25

Could you please format the code according to the README?

vermouth1992 · Mar 10 '25

Tested against torch and the current verl implementation; the improvement is huge.

[image: benchmark results]

Currently integrated into dp_actor.py.

ETOgaosion · Mar 20 '25

The integration has an OOM problem with the current fake-weight approach. We will reconsider how the linear layer is fused with cross entropy.

ETOgaosion · Mar 20 '25

A benefit of successful integration is that max_token_len can be increased significantly compared to not using this kernel.

vermouth1992 · Mar 20 '25

TP experiment results

[image: TP experiment results]

Jianbing-D · Mar 25 '25

Liger has a similar kernel called FusedLinearCrossEntropy

gameofdimension · Apr 08 '25

> Liger has a similar kernel called FusedLinearCrossEntropy

The kernel in Liger can't satisfy the requirement: there is additional loss computation after the kernel, which the Liger kernel can't support.

vermouth1992 · Apr 08 '25

End-to-end results:

  • https://api.wandb.ai/links/dongjbstrong-nvidia/sgwjz6eh
  • https://api.wandb.ai/links/dongjbstrong-nvidia/76qa4pbx

[images: end-to-end training curves]

Jianbing-D · Jun 06 '25

There are multiple CI failures. Could you please fix them? Thanks.

vermouth1992 · Jun 07 '25

Sorry for the close and reopen operations.

Opening a PR from the main branch can be risky for maintainers when cooperating and rebasing (QaQ).

Next time I'll still open a PR to the upstream repo, but from a dedicated branch instead of main.

ETOgaosion · Jun 08 '25

Does this PR improve the GRPO loss computation in terms of peak memory? I came across https://unsloth.ai/blog/grpo, which describes how to implement GRPO in the same chunked/fused style, so I wonder whether verl implements that technique as well.

vadimkantorov · Jul 08 '25

I was wondering whether the recent introduction of this feature might have contributed to the issue described in https://github.com/volcengine/verl/issues/2547.

hljjjmssyh · Jul 17 '25

Curious: why "For compatibility, torch.library.triton_op was not applied to these APIs, so torch.compile may not be usable on top of them"?

Wouldn't it be better to also be able to use torch.compile on the whole model / loss?

vadimkantorov · Jul 17 '25

I noticed some weird results after enabling kernel fusion, as described in #2656. Wondering whether it's a bug or I didn't use it correctly. @Jianbing-D

WindowsXp-Beta · Jul 21 '25

@WindowsXp-Beta are the problems present with both the torch and the triton fused backends?

vadimkantorov · Jul 21 '25

> @WindowsXp-Beta are the problems present with both the torch and the triton fused backends?

Sorry for the late response. I was testing whether it's caused by our internal model. Our current results show that the torch backend works normally, but the triton backend leads to reward collapse and an entropy mismatch. We're also testing on Qwen2.5-VL to see if the problem still exists.

WindowsXp-Beta · Jul 24 '25

cc @eric-haibin-lin @vermouth1992

vadimkantorov · Jul 25 '25

@vadimkantorov sorry for the late update. I spent some time setting up the environment to run Qwen2.5-VL on the mainline code. We found that the log_probs and entropy calculated by the fused kernel and the vanilla torch implementation matched for Qwen2.5-VL. So it looks like the problem is on our side, and we're still working on it.

WindowsXp-Beta · Jul 25 '25

Hi @vadimkantorov, after more tests we suspect the triton kernel may have bugs for certain hidden_states and weight values. See the latest comment in #2656 for details. I wonder if you have ever seen anything similar / could reproduce this mismatch on your side.

WindowsXp-Beta · Jul 25 '25
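
For anyone trying to reproduce the mismatch, a hypothetical seed sweep comparing the fused kernel against the pure-torch reference from the top of this thread (sizes and the error threshold are illustrative):

import torch
from verl.utils.kernel import linear_cross_entropy

num_tokens, hidden_size, vocab_size = 1024, 4096, 32768  # example sizes
for seed in range(100):
    torch.manual_seed(seed)
    hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
    labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
    loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
    loss_ref, entropy_ref = run_torch_entropy(hidden, weight, labels)
    err = (entropy.float() - entropy_ref.float()).abs().max().item()
    if err > 1e-2:  # flag seeds with a large per-token entropy error
        print(f"seed {seed}: max entropy error {err:.4f}")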