Fused Linear and Cross-Entropy Loss `torch.nn.functional.linear_cross_entropy`
🚀 The feature, motivation and pitch
It'd be great to have a fused linear and cross-entropy function in PyTorch, for example, torch.nn.functional.linear_cross_entropy. This function acts as a fused linear projection followed by a cross-entropy loss, e.g.
def linear_cross_entropy(linear_weights, input, labels):
    """Cross-entropy loss over a linear projection of *input*.

    Equivalent to ``F.cross_entropy(F.linear(input, linear_weights), labels)``:
    the logits are the projection of *input* by *linear_weights*, and the loss
    is the standard cross-entropy against *labels*.
    """
    return F.cross_entropy(F.linear(input, linear_weights), labels)
Compared to naive implementation, this fused function does not materialize the intermediate logits, saving a lot of VRAM.
This function is quite common in models, such as LLMs, classification, and so on. When the batch size or number of classes is large, the intermediate logits would take up the majority of VRAM.
For example, the latest LLMs, such as Llama 3 and Gemma, adopt extremely large vocabularies (128-256K). Thus, the size of logits can become very large, consuming a significant proportion of VRAM. The calculation below shows the VRAM size for Llama 3 8B with a batch size of 81920 tokens:
Logits size: 128256 (vocabulary size) * 81920 (batch size) * 2 bytes (bf16) = 19.57GiB
Hidden state size (checkpointed): 4096 (hidden size) * 81920 (batch size) * 32 (layers) * 2 bytes (bf16) = 20GiB
Therefore, a fused linear and cross-entropy loss function that does not require materializing full logits may reduce VRAM consumption by half.
Related: the same feature request for FlashAttention: https://github.com/Dao-AILab/flash-attention/issues/922
Alternatives
No response
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki
cc: @Chillee would be curious to hear your thoughts
Funny thing, I have been working on exactly that very recently (for Graphcore IPU, but nothing specific about the hardware in the technique). It's definitely possible to serialize & fuse the linear + softmax + cross entropy combination, the only trade-off being the recomputation necessary for the backward pass. Happy to help!
FWIW, I derived some formulas & wrote a (very rough) WIP prototype.
import torch
import triton
import triton.language as tl
# Forward Triton kernel for the fused LM-head + log-softmax + gather:
#   output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
# One program handles a BLOCK_N-row slice of tokens and streams over the
# vocabulary dimension V with an online softmax, so the full (N, V) logits
# matrix is never materialized.
# Pointer args: x_ND (N, D) activations, w_DV (D, V) projection weights,
# c_N (N,) integer target ids, output_N (N,) result, l_N (N,) saved softmax
# normalizer (presumably for the backward pass — backward is still TODO).
# Naming convention: suffix _bN / _bD / _bV = a block of that dimension.
@triton.jit
def fwd_kernel(
x_ND_ptr,
w_DV_ptr,
c_N_ptr,
output_N_ptr,
l_N_ptr,
N, D, V,
stride_xn, stride_xd,
stride_wd, stride_wv,
# Meta-parameters
BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_V: tl.constexpr,
):
# TODO: more parallelism, e.g. tiled softmax
# w/ parallelization across tiles (intra-tile computation uses online softmax)
# TODO: mask
# NOTE(review): no bounds masks anywhere, so this assumes N, D, V are exact
# multiples of BLOCK_N, BLOCK_D, BLOCK_V — confirm before general use.
# only parallelize along the N dimension
# i is the same as n
i = tl.program_id(axis=0)
offs_n_bN = i * BLOCK_N + tl.arange(0, BLOCK_N)
# Target class id for each row handled by this program.
c_i_bN = tl.load(c_N_ptr + offs_n_bN)
# Running (rescaled) probability mass at the target class, per row.
output_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32)
# statistics for online softmax
# m_i: running max logit; l_i: running sum of exp(logit - m_i).
# The "+ 1.0" init is harmless: on the first V-tile alpha = exp(-inf - m) = 0,
# which wipes it out.
m_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) - float('inf')
l_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) + 1.0
for start_v in range(0, V, BLOCK_V):
start_v = tl.multiple_of(start_v, BLOCK_V)
offs_v_bN = start_v + tl.arange(0, BLOCK_V)
# TODO: mask
x_ND_block_ptr = tl.make_block_ptr(
base=x_ND_ptr,
shape=(N, D),
strides=(stride_xn, stride_xd),
offsets=(i * BLOCK_N, 0),
block_shape=(BLOCK_N, BLOCK_D),
order=(1, 0),
)
w_DV_block_ptr = tl.make_block_ptr(
base=w_DV_ptr,
shape=(D, V),
strides=(stride_wd, stride_wv),
offsets=(0, start_v),
block_shape=(BLOCK_D, BLOCK_V),
order=(1, 0),
)
# Compute this (BLOCK_N, BLOCK_V) tile of logits = x @ w by marching both
# block pointers along the shared D dimension and accumulating tl.dot.
xw_bNbV = tl.zeros([BLOCK_N, BLOCK_V], dtype=tl.float32)
for start_d in range(0, D, BLOCK_D):
start_d = tl.multiple_of(start_d, BLOCK_D)
# TODO: mask
# TODO: x load can be reduced?
x_bNbD = tl.load(x_ND_block_ptr)
w_bDbV = tl.load(w_DV_block_ptr)
xw_bNbV = tl.dot(x_bNbD, w_bDbV, xw_bNbV)
x_ND_block_ptr = tl.advance(x_ND_block_ptr, (0, BLOCK_D))
w_DV_block_ptr = tl.advance(w_DV_block_ptr, (BLOCK_D, 0))
# i for N
# j for V
# Online-softmax update for this V-tile: new running max, shifted
# exponentials, tile-local partial sum.
m_ij_bN = tl.maximum(m_i_bN, tl.max(xw_bNbV, axis=1))
p_ij_bNbV = tl.exp(xw_bNbV - m_ij_bN[:, None])
l_ij_bN = tl.sum(p_ij_bNbV, axis=1)
# update m_i and l_i
# alpha rescales accumulators from the old running max to the new one.
alpha_bN = tl.exp(m_i_bN - m_ij_bN)
l_i_bN = l_i_bN * alpha_bN + l_ij_bN
m_i_bN = m_ij_bN
# update output
# Pick out exp(logit - m) at the target class if it lies in this V-tile
# (at most one position per row can match).
p_ic_bN = tl.sum(tl.where(c_i_bN[:, None] == offs_v_bN[None, :], p_ij_bNbV, 0.0), axis=1)
output_i_bN = output_i_bN * alpha_bN + p_ic_bN
# log(exp(logit_c - m) / sum exp(logit - m)) = log_softmax at the target.
output_i_bN = tl.log(output_i_bN) - tl.log(l_i_bN)
tl.store(output_N_ptr + offs_n_bN, output_i_bN)
# NOTE(review): only l (not the running max m) is stored; verify that the
# backward pass can recompute the normalizer without m, or save m as well.
tl.store(l_N_ptr + offs_n_bN, l_i_bN)
# output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
# WIP skeleton of the autograd wrapper around fwd_kernel. Both methods are
# TODO stubs: forward neither launches the kernel nor returns output_N, and
# backward loads the saved tensors but computes and returns no gradients.
class LMHeadThenLogSoftmaxThenGather(torch.autograd.Function):
@staticmethod
def forward(ctx, x_ND, w_DV, c_N):
# TODO
# l_N (the softmax normalizer the kernel will produce) is a placeholder.
l_N = None
ctx.save_for_backward(w_DV, x_ND, c_N, l_N)
@staticmethod
def backward(ctx, g_NV):
# TODO
w_DV, x_ND, c_N, l_N = ctx.saved_tensors
# Toy problem sizes: hidden dim D, vocabulary size V, number of tokens N.
D = 10
V = 20
N = 30
# Random leaf tensors for smoke-testing the fused op.
w_DV = torch.randn(D, V, requires_grad=True)
x_NV = torch.randn(N, D, requires_grad=True)  # NOTE: shape is (N, D) despite the name
c_N = torch.randint(high=V, size=(N,), dtype=torch.int64)
cc @eellison from triage meeting: is this something we can pattern match for in compile/is it already handled by the scheduler in compile?
Update: forward finished
import torch
import triton
import triton.language as tl
# Forward Triton kernel for the fused LM-head + log-softmax + gather:
#   output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
# One program handles a BLOCK_N-row slice of tokens and streams over the
# vocabulary dimension V with an online softmax, so the full (N, V) logits
# matrix is never materialized.
# Pointer args: x_ND (N, D) activations, w_DV (D, V) projection weights,
# c_N (N,) integer target ids, output_N (N,) result, l_N (N,) saved softmax
# normalizer (presumably for the backward pass — backward is still TODO).
# Naming convention: suffix _bN / _bD / _bV = a block of that dimension.
@triton.jit
def fwd_kernel(
x_ND_ptr,
w_DV_ptr,
c_N_ptr,
output_N_ptr,
l_N_ptr,
N, D, V,
stride_xn, stride_xd,
stride_wd, stride_wv,
# Meta-parameters
BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_V: tl.constexpr,
):
# TODO: more parallelism, e.g. tiled softmax
# w/ parallelization across tiles (intra-tile computation uses online softmax)
# TODO: mask
# NOTE(review): no bounds masks anywhere, so this assumes N, D, V are exact
# multiples of BLOCK_N, BLOCK_D, BLOCK_V — confirm before general use.
# only parallelize along the N dimension
# i is the same as n
i = tl.program_id(axis=0)
offs_n_bN = i * BLOCK_N + tl.arange(0, BLOCK_N)
# Target class id for each row handled by this program.
c_i_bN = tl.load(c_N_ptr + offs_n_bN)
# Running (rescaled) probability mass at the target class, per row.
output_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32)
# statistics for online softmax
# m_i: running max logit; l_i: running sum of exp(logit - m_i).
# The "+ 1.0" init is harmless: on the first V-tile alpha = exp(-inf - m) = 0,
# which wipes it out.
m_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) - float('inf')
l_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) + 1.0
for start_v in range(0, V, BLOCK_V):
start_v = tl.multiple_of(start_v, BLOCK_V)
offs_v_bN = start_v + tl.arange(0, BLOCK_V)
# TODO: mask
x_ND_block_ptr = tl.make_block_ptr(
base=x_ND_ptr,
shape=(N, D),
strides=(stride_xn, stride_xd),
offsets=(i * BLOCK_N, 0),
block_shape=(BLOCK_N, BLOCK_D),
order=(1, 0),
)
w_DV_block_ptr = tl.make_block_ptr(
base=w_DV_ptr,
shape=(D, V),
strides=(stride_wd, stride_wv),
offsets=(0, start_v),
block_shape=(BLOCK_D, BLOCK_V),
order=(1, 0),
)
# Compute this (BLOCK_N, BLOCK_V) tile of logits = x @ w by marching both
# block pointers along the shared D dimension and accumulating tl.dot.
xw_bNbV = tl.zeros([BLOCK_N, BLOCK_V], dtype=tl.float32)
for start_d in range(0, D, BLOCK_D):
start_d = tl.multiple_of(start_d, BLOCK_D)
# TODO: mask
# TODO: x load can be reduced?
x_bNbD = tl.load(x_ND_block_ptr)
w_bDbV = tl.load(w_DV_block_ptr)
xw_bNbV = tl.dot(x_bNbD, w_bDbV, xw_bNbV)
x_ND_block_ptr = tl.advance(x_ND_block_ptr, (0, BLOCK_D))
w_DV_block_ptr = tl.advance(w_DV_block_ptr, (BLOCK_D, 0))
# i for N
# j for V
# Online-softmax update for this V-tile: new running max, shifted
# exponentials, tile-local partial sum.
m_ij_bN = tl.maximum(m_i_bN, tl.max(xw_bNbV, axis=1))
p_ij_bNbV = tl.exp(xw_bNbV - m_ij_bN[:, None])
l_ij_bN = tl.sum(p_ij_bNbV, axis=1)
# update m_i and l_i
# alpha rescales accumulators from the old running max to the new one.
alpha_bN = tl.exp(m_i_bN - m_ij_bN)
l_i_bN = l_i_bN * alpha_bN + l_ij_bN
m_i_bN = m_ij_bN
# update output
# Pick out exp(logit - m) at the target class if it lies in this V-tile
# (at most one position per row can match).
p_ic_bN = tl.sum(tl.where(c_i_bN[:, None] == offs_v_bN[None, :], p_ij_bNbV, 0.0), axis=1)
output_i_bN = output_i_bN * alpha_bN + p_ic_bN
# log(exp(logit_c - m) / sum exp(logit - m)) = log_softmax at the target.
output_i_bN = tl.log(output_i_bN) - tl.log(l_i_bN)
tl.store(output_N_ptr + offs_n_bN, output_i_bN)
# NOTE(review): only l (not the running max m) is stored; verify that the
# backward pass can recompute the normalizer without m, or save m as well.
tl.store(l_N_ptr + offs_n_bN, l_i_bN)
# output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
# Autograd wrapper around fwd_kernel. Forward is complete; backward is still
# a TODO stub (it loads the saved tensors but returns no gradients).
class LMHeadThenLogSoftmaxThenGather(torch.autograd.Function):
@staticmethod
def forward(ctx, x_ND: torch.Tensor, w_DV: torch.Tensor, c_N: torch.Tensor):
# TODO
# x_ND: (N, D) activations; w_DV: (D, V) weights; c_N: (N,) target ids.
N, D = x_ND.shape
Dw, V = w_DV.shape
Nc, = c_N.shape
assert D == Dw and N == Nc
# output_N: per-token log-prob at the target class.
# l_N: softmax normalizer saved for the (not yet written) backward pass.
output_N = x_ND.new_empty(N)
l_N = x_ND.new_empty(N)
# Fixed tile sizes. The kernel has no bounds masks yet, so N, D, V are
# presumably required to be multiples of these — TODO confirm / add masks.
BLOCK_N = 32
BLOCK_D = 32
BLOCK_V = 32
# One program per BLOCK_N token rows.
grid = (triton.cdiv(N, BLOCK_N),)
fwd_kernel[grid](
x_ND, w_DV, c_N,
output_N, l_N,
N, D, V,
x_ND.stride(0), x_ND.stride(1),
w_DV.stride(0), w_DV.stride(1),
BLOCK_N, BLOCK_D, BLOCK_V,
)
ctx.save_for_backward(w_DV, x_ND, c_N, l_N)
return output_N
@staticmethod
def backward(ctx, g_NV):
# TODO
w_DV, x_ND, c_N, l_N = ctx.saved_tensors
# Accuracy smoke test at Llama-3-like sizes (hidden 4096, vocab 131072,
# 64Ki tokens), bf16, on a CUDA device.
D, V, N = 4096, 131072, 65536
w_DV = torch.empty(D, V, dtype=torch.bfloat16, device='cuda').normal_(std= D ** -0.5)
x_NV = torch.empty(N, D, dtype=torch.bfloat16, device='cuda').normal_()
c_N = torch.empty(N, dtype=torch.int64, device='cuda').random_(V)
output_N: torch.Tensor = LMHeadThenLogSoftmaxThenGather.apply(x_NV, w_DV, c_N)
# Reference: materialize the full logits and gather log_softmax at the targets.
ref = torch.nn.functional.log_softmax(x_NV @ w_DV, dim=1)[torch.arange(N), c_N].double()
# Per-token relative error; report mean plus 95th and 99th percentiles.
err = abs(output_N.double() - ref) / abs(ref)
print(torch.mean(err).item(), torch.quantile(err, 0.95).item(), torch.quantile(err, 0.99).item())
Output: 9.208518284318959e-05 0.0 0.004878048780487805
hi @imoneoi thank you for the raising the issue, I'm suffering from the same issue and would be very happy if this feature is added to torch ! i have a question about the batch size you mentioned, 81920. is this equivalent to the number of valid tokens used for computing loss ? ( B*T where B is batch size and T is sequence length)
This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/
@Chillee I considered this type of optimization before @imoneoi posted this issue. Similar optimization could be implemented easily with checkpoint, pseudo code:
loss = 0.
for x_chunk, y_chunk in zip(x.split(num_chunks), y.split(num_chunks)):
loss += checkpoint(lambda w, x, y: ce_loss(x @ w, y), w, x_chunk, y_chunk)
But this implementation has a drawback: it needs to load the huge w multiple times (once per chunk), which costs num_chunks times the memory bandwidth. Using a custom kernel would avoid that unnecessary re-computation. checkpoint is a clever idea!
This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/
i just tested this with llama3 8B by monkey patching huggingface transformers' model class
- 1x 80GB A100
- input size [B, T] = [1, 32768]
- bf16
- flash attention (torch SDPA)
- offloading param and gradient to CPU (because i want to scale up with FSDP after profiling)
- activation checkpointing (in every layer and offload to CPU)
- fused CE loss (n_loop_iters=8)
i recorded GPU memory of 3 fwd-bwd steps form the beginning (i didn't care warmup)
and this is for [B, T] = [4, 20480] @imoneoi
@SeunghyunSEO Could you please include the memory usage curve of the baseline (original HF transformers' model)?
@YouJiacheng of course i can, but the original model's peak memory would be greater than 12GB, because full precision logits is already 15++ GB (B*T*vocab*float32 = 1*32768*128256*4)
Reasonable. I am just curious about the peak and fluctuation on the curve.
Read the benchmark result in the repo, it seems that loading the weight multiple times incurs nearly 0 cost.
-- Ah, I forgot that matrix multiplication with large matrices will load both matrices multiple times into SRAM. As long as the block size is large enough, memory won't be a bottleneck.
Following this (simplified) algorithm, x @ w will load w M / BLOCK_SIZE_M times if x.shape is (M, K). Thus, as long as the chunk size is a multiple of BLOCK_SIZE_M, there is no extra memory read.
In practice, the hierarchy is not only
SRAM-DRAM, it can be Shared_Memory-L2_Cache-DRAM, but the story is similar.
https://github.com/mgmalek/efficient_cross_entropy can even be implemented in pure pytorch without a custom kernel, but an in-place softmax is required to achieve the "Optimization 1". But if the loss needs to be computed with good numerical stability, a custom kernel will be unavoidable to reduce memory overhead (note: with in-place softmax we can get a numerical stable grad, but we cannot compute a numerical stable loss without log_softmax).
@YouJiacheng yeah yeah wll be uploaded!
This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/
i just tested this with llama3 8B by monkey patching huggingface transformers' model class and this is for
[B, T] = [4, 20480]
@SeunghyunSEO Did you have to make changes to the kernel to work with batch sizes > 1?
This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/
i just tested this with llama3 8B by monkey patching huggingface transformers' model class and this is for
[B, T] = [4, 20480]
@SeunghyunSEO Did you have to make changes to the kernel to work with batch sizes > 1?
@winglian you can simply flatten these two axis/dim.
This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/
i just tested this with llama3 8B by monkey patching huggingface transformers' model class and this is for
[B, T] = [4, 20480]
@SeunghyunSEO Did you have to make changes to the kernel to work with batch sizes > 1?
no i didn't. like @YouJiacheng said, you can flatten last hidden. but it's not optional, it's essential. see this line
@SeunghyunSEO Could you please include the memory usage curve of the baseline (original HF transformers' model)?
@YouJiacheng this is the result for the case where all settings are the same but CE loss is not fused. i changed the original code little bit like this
logits = self.lm_head(hidden_states)
logits = logits.float()
logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
labels = labels[:, 1:].contiguous().view(-1).to(logits.device)
loss = CrossEntropyLoss()(logits, labels)
release_memory()
- 1x 80GB A100
- input size [B, T] = [1, 32768]
- bf16
- flash attention (torch SDPA)
- offloading param and gradient to CPU (because i want to scale up with FSDP after profiling)
- activation checkpointing (in every layer and offload to CPU)
i think fused kernel upcast logits to float32 too, so it's fair right?
added) since i recorded the first three steps (to see the change in GPU memory clearly), there may be noise in the time complexity measurement.
After reflection, I realize that chunking tokens is better than tiled softmax method for this use-case.
The tiled softmax method would require a re-computation of p in the backward pass, but chunking tokens with a clever custom kernel doesn't.
After reflection, I realize that chunking tokens is better than tiled softmax method for this use-case. The tiled softmax method would require a re-computation of
pin the backward pass, but chunking tokens with a clever custom kernel doesn't.
oh, i guess the performances of tiled softmax (like flash attention and your thought) and this kernel would be similar, but seems like this kernel is slightly better right?
btw i have one question for this kernel.
following this line, if i see num_iters=8, it does not allow some sequences with 1510 or something.
If so, should i pad hidden with -inf and labels with -100 ?
(there is nothing wrong with that in my opinion, but...)
(I'm afraid i wont be able to contribute well to this discussion because i'm new to low level optimization like triton, lol)
@SeunghyunSEO
This kernel doesn't support ignore_index, currently. You can also tweak this line to the actual number of tokens in this chunk, and tweak this line to allow non-uniform chunking.
The tiled softmax would require a re-computation, increase ~33% FLOPs.
This kernel can be even better than https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py since it only loads the logits ONCE, while other implementations load the logits TWICE (forward + backward).
This kernel doesn't support
ignore_index, currently. You can tweak [this line]
this line isn't sufficient to support ignore_index? https://github.com/mgmalek/efficient_cross_entropy/blob/049d44460051a82f58f7ff49a2ad0653ecf026d8/modules.py#L56
@winglian Sorry, I overlooked this line!
@SeunghyunSEO ~~This kernel doesn't support ignore_index, currently.~~ You can also tweak this line to the actual number of tokens in this chunk, and tweak this line to allow non-uniform chunking. The tiled softmax would require a re-computation, increasing FLOPs by ~33%.
This kernel can be even better than https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py since it only loads the logits ONCE, while other implementations load the logits TWICE (forward + backward).
agree with that. what a kernel XD and like @winglian said, it looks like supporting ignore_index too!
@balancap I'm curious how you handle the back to back gemm in the backward pass if we don't want to materialize the large logits tensor (for better perf and for saving peak memory usage).
In the backward pass, we need recompute the logits with a gemm: [BxT, D] x [D, V] = [BxT, V]. (B batch size, T sequence length, D hidden dimension; V vocabulary size).
And then we need do 2 gemm to
- compute gradients for weights: [BxT, D].t x [BxT, V]
- compute gradients for inputs: [BxT, V] x [D, V].t
Potentially, we need do two back-to-back gemms here to compute gradients of inputs and weights.
I'm thinking we can do similar thing as flash attention. But the difference here is the tensors here are not as 'skinny' as flash attention: in flash attention, the size of each row of the tensor (parameter d in the paper) is per head embedding dimension, while here the size of each row is the hidden dimension (which is H times larger. H for number of heads). Maybe the back-to-back gemm can still be efficiently implemented. But I would like to ask if the larger row size here cause any trouble when you implement similar thing for Graphcore.
I tested mgmalek's kernel in a forked repo. https://github.com/kjslag/efficient_cross_entropy
Some tests fail when I include tests that use ignore_index.
This happened to be very relevant for a project of ours, so I spent (too much) time looking into it, recently.
Just for reference, I think what this thread needs the most is a few more benchmarks, so here are some baselines in TFLOPs, memory and accuracy, when moving over N (number of tokens), V (vocab size) and H (hidden dim) for default values of N=16384, H=2048, V=131072:
triton-z-chunks-in-sram is https://gist.github.com/JonasGeiping/6b724907ceb35555a6168dda9b9c4136, which is a variation of Malek's code (linked above), but more accurate (and sometimes faster, but also with more autotune behind it). I think this variant could be even faster if the prologue is fused into each of the two backward matmuls, but that's beyond my triton abilities to implement efficiently.
The torch checkpoint baseline is
def torch_checkpoint(x, y, A, default_chunk_size=512):
    """Chunked, activation-checkpointed linear + cross-entropy baseline.

    Flattens ``x`` to ``(N, H)`` and ``y`` to ``(N,)``, splits both into
    equal chunks, and accumulates ``_inner_function`` (the per-chunk
    projection + loss) under ``torch.utils.checkpoint`` so each chunk's
    logits are recomputed in backward instead of stored.

    Args:
        x: activations, any shape reshapeable to (N, H).
        y: targets, any shape reshapeable to (N,).
        A: projection weight of shape (_, H).
        default_chunk_size: preferred tokens per chunk; reduced to a common
            divisor of N when it does not tile N exactly.

    Returns:
        The accumulated loss over all chunks.
    """
    _, H = A.shape
    x_flat = x.view(-1, H)
    y_flat = y.view(-1)
    N = x_flat.shape[0]
    chunk_size = min(default_chunk_size, N)
    # Chunks must tile N exactly for uniform splits.
    # BUG FIX: the original tested `chunk_size % N != 0`, which is nonzero
    # whenever N > chunk_size, so the gcd fallback ran even when chunk_size
    # already divided N evenly. The intended test is `N % chunk_size`.
    if N % chunk_size != 0:
        # gcd always divides N, so splits stay uniform (may be small for odd N).
        chunk_size = math.gcd(N, default_chunk_size)
    x_blocks = x_flat.split(chunk_size)
    y_blocks = y_flat.split(chunk_size)
    loss = 0.0
    for x_block, y_block in zip(x_blocks, y_blocks):
        loss += checkpoint(_inner_function, x_block, y_block, A,
                           num_blocks=len(y_blocks), use_reentrant=False)
    return loss
(which needs a rewrite to reliably torch.compile)
Overall, having the logits z=Ax stored in SRAM is a bit of an unsatisfying solution to me, they need to be in float32 for precision until softmax(z) = (z - lse).exp() is computed, and so they introduce a lot of memory transfer, only to be immediately discarded, but I couldn't make other versions work.
The problem is, as also discussed above, that H is just too large. The forward pass is nice and fast, but in the backward pass, rematerializing z=Ax requires parallelization over (N,V) and accumulation over H, while the backward gradients require parallelization over (N,H) and (V,H), respectively. The other constraint is that matmul tiles need minimum shapes, so the gradient shapes should be at least, e.g. for dx: 16 x H. If this fits, a single operation can recompute z and immediately compute dx and dA. It doesn't fit for normal H sizes like 4096, leading to a ton of memory load and write traffic (see failed attempts here or here).
https://github.com/JonasGeiping/linear_cross_entropy_loss
I have a question, do we need to compute the forward pass result? I think we can directly fuse all the things with a backward kernel during training, and thus, there is no need to save the result?
Yeah, both Malek's version and mine (both linked above), directly compute gradients during the forward pass loop over N and discard anything else. During the backward pass, nothing happens except for a scalar multiplication.
Great!
Yeah, both Malek's version and mine (both linked above), directly compute gradients during the forward pass loop over N and discard anything else. During the backward pass, nothing happens except for a scalar multiplication.