Memory-efficient cross entropy with a fused linear layer
Implements the forward and backward passes of the compute logic below, eliminating many intermediate tensors and reducing peak memory usage.

Equivalent compute logic:
```python
import typing

import torch


def run_torch_entropy(hidden: torch.Tensor,
                      weight: torch.Tensor,
                      labels: torch.Tensor) -> typing.List[torch.Tensor]:
    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
    pd = torch.nn.functional.softmax(logits, dim=-1)   # [num_tokens, vocab_size]
    entropy_a = torch.logsumexp(logits, dim=-1)        # [num_tokens]
    entropy_b = torch.sum(pd * logits, dim=-1)         # [num_tokens]
    entropy = entropy_a - entropy_b                    # [num_tokens]
    logprobs = torch.nn.functional.cross_entropy(logits, labels)  # scalar loss
    return logprobs, entropy
```
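For intuition, the memory saving comes from never materializing the full `[num_tokens, vocab_size]` logits tensor at once. Below is a minimal chunked sketch of that idea, for illustration only; it is not the actual fused kernel, and the helper name and chunk size are made up:

```python
import torch


def chunked_entropy_reference(hidden, weight, labels, chunk_size=1024):
    """Illustrative chunked reference: only a [chunk_size, vocab_size]
    logits tile is alive at any time instead of the full logits tensor."""
    losses, entropies = [], []
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size].to(torch.float32)
        logits = h @ weight.to(torch.float32)             # [chunk, vocab_size]
        pd = torch.nn.functional.softmax(logits, dim=-1)  # [chunk, vocab_size]
        lse = torch.logsumexp(logits, dim=-1)             # [chunk]
        entropies.append(lse - torch.sum(pd * logits, dim=-1))
        losses.append(torch.nn.functional.cross_entropy(
            logits, labels[start:start + chunk_size], reduction="none"))
    # mean over all tokens matches reduction="mean" in the unchunked version
    return torch.cat(losses).mean(), torch.cat(entropies)
```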
API

```python
import torch

from verl.utils.kernel import linear_cross_entropy

num_tokens, hidden_size, vocab_size = 512, 1024, 32000  # example sizes

hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
```
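A typical training-step usage might look like the following sketch; the entropy coefficient and the exact objective are illustrative assumptions, not part of the API:

```python
# Hypothetical training-step usage; 0.01 is an assumed entropy coefficient.
hidden.requires_grad_(True)
weight.requires_grad_(True)

loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
objective = loss - 0.01 * entropy.mean()  # entropy bonus, for illustration only
objective.backward()                      # gradients flow into hidden and weight
```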
Storage and latency
Unit test

```bash
$ cd verl/
$ python3 tests/kernel/test_memory_efficient_entropy.py
```
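For a quick sanity check outside the test suite, the fused outputs can be compared against `run_torch_entropy` above. This is a sketch; the tolerances and dtype handling are assumptions and may need adjusting for bfloat16 inputs:

```python
# Reuses hidden / weight / labels from the API example and the
# run_torch_entropy reference defined above.
loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
ref_loss, ref_entropy = run_torch_entropy(hidden, weight, labels)

torch.testing.assert_close(loss, ref_loss, rtol=1e-2, atol=1e-2, check_dtype=False)
torch.testing.assert_close(entropy, ref_entropy, rtol=1e-2, atol=1e-2, check_dtype=False)
```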
NOTE
For compatibility, `torch.library.triton_op` was not applied to these APIs, so `torch.compile` may not be usable on top of them.
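To make the trade-off concrete, below is a sketch of what registering the op through `torch.library` could look like so that `torch.compile` treats it as a single opaque node. It uses `torch.library.custom_op` (a related registration path) rather than `triton_op`, the op name and reference body are hypothetical, and this is not what the PR does:

```python
import torch


# Hypothetical registration so torch.compile sees one opaque op instead of
# tracing into the kernel. "verl::linear_cross_entropy_demo" is made up,
# and the body is a plain torch reference, not the fused kernel. For
# simplicity this demo returns only the loss, not the entropy.
@torch.library.custom_op("verl::linear_cross_entropy_demo", mutates_args=())
def linear_cross_entropy_demo(hidden: torch.Tensor,
                              weight: torch.Tensor,
                              labels: torch.Tensor) -> torch.Tensor:
    logits = hidden.float() @ weight.float()
    return torch.nn.functional.cross_entropy(logits, labels)


# A fake (meta) implementation lets torch.compile infer output shapes
# without running the kernel.
@linear_cross_entropy_demo.register_fake
def _(hidden, weight, labels):
    return hidden.new_empty((), dtype=torch.float32)
```

Since the PR skips such a registration, `torch.compile` may not be able to trace through the fused op.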
Could you please run the formatting steps described in the README?
Tested against torch and the current VeRL implementation; the improvement is substantial.
Currently integrated into dp_actor.py.
The integration hits OOM with the current fake-weight approach; we will reconsider how the linear layer is fused with the cross entropy.
One success of the integration is that max_token_len can be increased significantly compared to not using this kernel.
TP experiment result
Liger has a similar kernel called FusedLinearCrossEntropy
> Liger has a similar kernel called FusedLinearCrossEntropy
The Liger kernel doesn't satisfy our requirement: there is additional loss computation after the kernel, which the Liger kernel can't support.
End2End results:
- https://api.wandb.ai/links/dongjbstrong-nvidia/sgwjz6eh
- https://api.wandb.ai/links/dongjbstrong-nvidia/76qa4pbx
There are multiple CI failures. Could you please fix them? Thanks.
Sorry for the close-and-reopen operations.
Opening the PR from the main branch can be risky for maintainers who need to cooperate and rebase (QaQ).
Next time I will still open PRs to others' repos, just not from my main branch.
Does this PR improve GRPO loss computation in terms of peak memory? I've come across https://unsloth.ai/blog/grpo, which describes how to implement GRPO in a chunked/fused style as well, so I wonder whether verl implements such a technique too.
I was wondering whether the recent introduction of this feature might have contributed to the issue described in https://github.com/volcengine/verl/issues/2547.
> For compatibility, `torch.library.triton_op` was not applied to these APIs, so `torch.compile` may not be usable on top of them.

Curious, why is that? Wouldn't it be better to also be able to use `torch.compile` on the whole model / loss?
I noticed some weird results after enabling kernel fusion, as described in #2656. I'm wondering if it's a bug or if I didn't use it correctly. @Jianbing-D
@WindowsXp-Beta are the problems present with both the torch and triton fused backends?
Sorry for the late response. I was testing whether it's caused by our internal model. Our current results show the torch backend works normally, but the triton backend leads to reward collapse and an entropy mismatch. We're also testing on Qwen2.5-VL to see if the problem still exists.
cc @eric-haibin-lin @vermouth1992
@vadimkantorov sorry for the late update. I spent some time setting up the environment to run Qwen2.5-VL with the mainline code. We found that the log_probs and entropy calculated by the fused kernel and the vanilla torch implementation matched for Qwen2.5-VL. So it looks like the problem is on our side, and we're still working on it.
Hi @vadimkantorov, after more tests we suspect the triton kernel may have bugs for certain hidden_states and weight values. See the latest comment in #2656 for details. I wonder if you have seen anything similar or could reproduce this mismatch on your side.