
Fused Linear and Cross-Entropy Loss `torch.nn.functional.linear_cross_entropy`

Open imoneoi opened this issue 2 years ago • 23 comments

🚀 The feature, motivation and pitch

It'd be great to have a fused linear and cross-entropy function in PyTorch, for example, torch.nn.functional.linear_cross_entropy. This function acts as a fused linear projection followed by a cross-entropy loss, e.g.

import torch.nn.functional as F

def linear_cross_entropy(linear_weights, input, labels):
    logits = F.linear(input, linear_weights)
    return F.cross_entropy(logits, labels)

Compared to the naive implementation, this fused function does not materialize the intermediate logits, saving a large amount of VRAM.

This pattern is quite common in models such as LLMs and classifiers. When the batch size or the number of classes is large, the intermediate logits take up the majority of VRAM.

For example, the latest LLMs, such as Llama 3 and Gemma, adopt extremely large vocabularies (128-256K). Thus, the size of logits can become very large, consuming a significant proportion of VRAM. The calculation below shows the VRAM size for Llama 3 8B with a batch size of 81920 tokens:

Logits size: 128256 (vocabulary size) * 81920 (batch size) * 2 bytes (bf16) = 19.57GiB
Hidden state size (checkpointed): 4096 (hidden size) * 81920 (batch size) * 32 (layers) * 2 bytes (bf16) = 20GiB

Therefore, a fused linear and cross-entropy loss function that does not require materializing full logits may reduce VRAM consumption by half.
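For reference, the arithmetic above can be reproduced in a few lines:

```python
# Back-of-envelope check of the VRAM numbers above for Llama 3 8B.
GiB = 2 ** 30
vocab, batch_tokens, hidden, layers, bf16_bytes = 128256, 81920, 4096, 32, 2

logits_bytes = vocab * batch_tokens * bf16_bytes
hidden_bytes = hidden * batch_tokens * layers * bf16_bytes

print(f"logits: {logits_bytes / GiB:.2f} GiB")                   # 19.57 GiB
print(f"hidden states (checkpointed): {hidden_bytes / GiB:.2f} GiB")  # 20.00 GiB
```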

Related: the same feature request for FlashAttention: https://github.com/Dao-AILab/flash-attention/issues/922

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

imoneoi avatar Apr 19 '24 14:04 imoneoi

cc: @Chillee would be curious to hear your thoughts

awgu avatar Apr 19 '24 14:04 awgu

Funny thing, I have been working on exactly that very recently (for the Graphcore IPU, but nothing in the technique is hardware-specific). It's definitely possible to serialize and fuse the linear + softmax + cross-entropy combination; the only trade-off is the recomputation needed for the backward pass. Happy to help!

balancap avatar Apr 19 '24 14:04 balancap

FWIW, I derived some formulas & wrote a (very rough) WIP prototype.

import torch

import triton
import triton.language as tl

@triton.jit
def fwd_kernel(
    x_ND_ptr,
    w_DV_ptr,
    c_N_ptr,
    output_N_ptr,
    l_N_ptr,
    N, D, V,
    stride_xn, stride_xd,
    stride_wd, stride_wv,
    # Meta-parameters
    BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_V: tl.constexpr,
):
    # TODO: more parallelism, e.g. tiled softmax 
    #       w/ parallelization across tiles (intra-tile computation uses online softmax)
    # TODO: mask
    # only parallelize along the N dimension
    # i is the same as n
    i = tl.program_id(axis=0)
    offs_n_bN = i * BLOCK_N + tl.arange(0, BLOCK_N)
    c_i_bN = tl.load(c_N_ptr + offs_n_bN)
    output_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32)

    # statistics for online softmax
    m_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) - float('inf')
    l_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) + 1.0
    for start_v in range(0, V, BLOCK_V):
        start_v = tl.multiple_of(start_v, BLOCK_V)
        offs_v_bN = start_v + tl.arange(0, BLOCK_V)
        # TODO: mask
        x_ND_block_ptr = tl.make_block_ptr(
            base=x_ND_ptr,
            shape=(N, D),
            strides=(stride_xn, stride_xd),
            offsets=(i * BLOCK_N, 0),
            block_shape=(BLOCK_N, BLOCK_D),
            order=(1, 0),
        )
        w_DV_block_ptr = tl.make_block_ptr(
            base=w_DV_ptr,
            shape=(D, V),
            strides=(stride_wd, stride_wv),
            offsets=(0, start_v),
            block_shape=(BLOCK_D, BLOCK_V),
            order=(1, 0),
        )
        xw_bNbV = tl.zeros([BLOCK_N, BLOCK_V], dtype=tl.float32)
        for start_d in range(0, D, BLOCK_D):
            start_d = tl.multiple_of(start_d, BLOCK_D)
            # TODO: mask
            # TODO: x load can be reduced?
            x_bNbD = tl.load(x_ND_block_ptr)
            w_bDbV = tl.load(w_DV_block_ptr)
            xw_bNbV = tl.dot(x_bNbD, w_bDbV, xw_bNbV)
            x_ND_block_ptr = tl.advance(x_ND_block_ptr, (0, BLOCK_D))
            w_DV_block_ptr = tl.advance(w_DV_block_ptr, (BLOCK_D, 0))
        
        # i for N
        # j for V
        m_ij_bN = tl.maximum(m_i_bN, tl.max(xw_bNbV, axis=1))
        p_ij_bNbV = tl.exp(xw_bNbV - m_ij_bN[:, None])
        l_ij_bN = tl.sum(p_ij_bNbV, axis=1)
        # update m_i and l_i
        alpha_bN = tl.exp(m_i_bN - m_ij_bN)
        l_i_bN = l_i_bN * alpha_bN + l_ij_bN
        m_i_bN = m_ij_bN
        # update output
        p_ic_bN = tl.sum(tl.where(c_i_bN[:, None] == offs_v_bN[None, :], p_ij_bNbV, 0.0), axis=1)
        output_i_bN = output_i_bN * alpha_bN + p_ic_bN

    output_i_bN = tl.log(output_i_bN) - tl.log(l_i_bN)
    tl.store(output_N_ptr + offs_n_bN, output_i_bN)
    tl.store(l_N_ptr + offs_n_bN, l_i_bN)


# output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
class LMHeadThenLogSoftmaxThenGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_ND, w_DV, c_N):
        # TODO
        l_N = None 
        ctx.save_for_backward(w_DV, x_ND, c_N, l_N)
        

    @staticmethod
    def backward(ctx, g_NV):
        # TODO
        w_DV, x_ND, c_N, l_N = ctx.saved_tensors

D, V, N = 10, 20, 30
w_DV = torch.randn(D, V, requires_grad=True)
x_ND = torch.randn(N, D, requires_grad=True)
c_N = torch.randint(V, (N,), dtype=torch.int64)


YouJiacheng avatar Apr 19 '24 17:04 YouJiacheng

cc @eellison from triage meeting: is this something we can pattern match for in compile/is it already handled by the scheduler in compile?

mikaylagawarecki avatar Apr 19 '24 22:04 mikaylagawarecki

Update: forward finished

import torch

import triton
import triton.language as tl

@triton.jit
def fwd_kernel(
    x_ND_ptr,
    w_DV_ptr,
    c_N_ptr,
    output_N_ptr,
    l_N_ptr,
    N, D, V,
    stride_xn, stride_xd,
    stride_wd, stride_wv,
    # Meta-parameters
    BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_V: tl.constexpr,
):
    # TODO: more parallelism, e.g. tiled softmax 
    #       w/ parallelization across tiles (intra-tile computation uses online softmax)
    # TODO: mask
    # only parallelize along the N dimension
    # i is the same as n
    i = tl.program_id(axis=0)
    offs_n_bN = i * BLOCK_N + tl.arange(0, BLOCK_N)
    c_i_bN = tl.load(c_N_ptr + offs_n_bN)
    output_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32)

    # statistics for online softmax
    m_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) - float('inf')
    l_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) + 1.0
    for start_v in range(0, V, BLOCK_V):
        start_v = tl.multiple_of(start_v, BLOCK_V)
        offs_v_bN = start_v + tl.arange(0, BLOCK_V)
        # TODO: mask
        x_ND_block_ptr = tl.make_block_ptr(
            base=x_ND_ptr,
            shape=(N, D),
            strides=(stride_xn, stride_xd),
            offsets=(i * BLOCK_N, 0),
            block_shape=(BLOCK_N, BLOCK_D),
            order=(1, 0),
        )
        w_DV_block_ptr = tl.make_block_ptr(
            base=w_DV_ptr,
            shape=(D, V),
            strides=(stride_wd, stride_wv),
            offsets=(0, start_v),
            block_shape=(BLOCK_D, BLOCK_V),
            order=(1, 0),
        )
        xw_bNbV = tl.zeros([BLOCK_N, BLOCK_V], dtype=tl.float32)
        for start_d in range(0, D, BLOCK_D):
            start_d = tl.multiple_of(start_d, BLOCK_D)
            # TODO: mask
            # TODO: x load can be reduced?
            x_bNbD = tl.load(x_ND_block_ptr)
            w_bDbV = tl.load(w_DV_block_ptr)
            xw_bNbV = tl.dot(x_bNbD, w_bDbV, xw_bNbV)
            x_ND_block_ptr = tl.advance(x_ND_block_ptr, (0, BLOCK_D))
            w_DV_block_ptr = tl.advance(w_DV_block_ptr, (BLOCK_D, 0))
        
        # i for N
        # j for V
        m_ij_bN = tl.maximum(m_i_bN, tl.max(xw_bNbV, axis=1))
        p_ij_bNbV = tl.exp(xw_bNbV - m_ij_bN[:, None])
        l_ij_bN = tl.sum(p_ij_bNbV, axis=1)
        # update m_i and l_i
        alpha_bN = tl.exp(m_i_bN - m_ij_bN)
        l_i_bN = l_i_bN * alpha_bN + l_ij_bN
        m_i_bN = m_ij_bN
        # update output
        p_ic_bN = tl.sum(tl.where(c_i_bN[:, None] == offs_v_bN[None, :], p_ij_bNbV, 0.0), axis=1)
        output_i_bN = output_i_bN * alpha_bN + p_ic_bN

    output_i_bN = tl.log(output_i_bN) - tl.log(l_i_bN)
    tl.store(output_N_ptr + offs_n_bN, output_i_bN)
    tl.store(l_N_ptr + offs_n_bN, l_i_bN)


# output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
class LMHeadThenLogSoftmaxThenGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_ND: torch.Tensor, w_DV: torch.Tensor, c_N: torch.Tensor):
        # TODO
        N, D = x_ND.shape
        Dw, V = w_DV.shape
        Nc, = c_N.shape
        assert D == Dw and N == Nc
        output_N = x_ND.new_empty(N)
        l_N = x_ND.new_empty(N)
        BLOCK_N = 32
        BLOCK_D = 32
        BLOCK_V = 32
        grid = (triton.cdiv(N, BLOCK_N),)
        fwd_kernel[grid](
            x_ND, w_DV, c_N,
            output_N, l_N,
            N, D, V,
            x_ND.stride(0), x_ND.stride(1),
            w_DV.stride(0), w_DV.stride(1),
            BLOCK_N, BLOCK_D, BLOCK_V,
        )
        ctx.save_for_backward(w_DV, x_ND, c_N, l_N)
        return output_N
        

    @staticmethod
    def backward(ctx, g_NV):
        # TODO
        w_DV, x_ND, c_N, l_N = ctx.saved_tensors

D, V, N = 4096, 131072, 65536
w_DV = torch.empty(D, V, dtype=torch.bfloat16, device='cuda').normal_(std=D ** -0.5)
x_ND = torch.empty(N, D, dtype=torch.bfloat16, device='cuda').normal_()
c_N = torch.empty(N, dtype=torch.int64, device='cuda').random_(V)

output_N: torch.Tensor = LMHeadThenLogSoftmaxThenGather.apply(x_ND, w_DV, c_N)
ref = torch.nn.functional.log_softmax(x_ND @ w_DV, dim=1)[torch.arange(N), c_N].double()
err = abs(output_N.double() - ref) / abs(ref)
print(torch.mean(err).item(), torch.quantile(err, 0.95).item(), torch.quantile(err, 0.99).item())

Output: 9.208518284318959e-05 0.0 0.004878048780487805

YouJiacheng avatar Apr 26 '24 10:04 YouJiacheng

hi @imoneoi, thank you for raising the issue. I'm suffering from the same issue and would be very happy if this feature were added to torch! i have a question about the batch size you mentioned, 81920. is this equivalent to the number of valid tokens used for computing the loss? (B*T, where B is the batch size and T is the sequence length)

SeunghyunSEO avatar May 02 '24 09:05 SeunghyunSEO

This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/

Chillee avatar May 03 '24 22:05 Chillee

@Chillee I considered this type of optimization before @imoneoi posted this issue. A similar optimization can be implemented easily with checkpoint; pseudocode:

loss = 0.
for x_chunk, y_chunk in zip(x.chunk(num_chunks), y.chunk(num_chunks)):
    loss += checkpoint(lambda w, x, y: ce_loss(x @ w, y), w, x_chunk, y_chunk)

But this implementation has a drawback: it needs to load the huge w multiple times, costing num_chunks times the memory bandwidth. Using a custom kernel to avoid the unnecessary re-computation in checkpoint is a clever idea!
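For reference, a runnable version of this sketch (toy sizes; the per-chunk "sum" reduction, which the pseudocode glosses over, makes the total match the usual mean-reduced loss):

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def chunked_ce_loss(w, x, y, num_chunks=4):
    # Sum per-chunk losses, then divide by N: equals the mean-reduced loss.
    loss = x.new_zeros(())
    for x_chunk, y_chunk in zip(x.chunk(num_chunks), y.chunk(num_chunks)):
        loss = loss + checkpoint(
            lambda w, xc, yc: F.cross_entropy(xc @ w, yc, reduction="sum"),
            w, x_chunk, y_chunk, use_reentrant=False)
    return loss / x.shape[0]

N, D, V = 64, 32, 100
w = torch.randn(D, V, requires_grad=True)
x = torch.randn(N, D)
y = torch.randint(V, (N,))
print(torch.allclose(chunked_ce_loss(w, x, y), F.cross_entropy(x @ w, y), atol=1e-6))
```

With `use_reentrant=False`, each chunk's logits are recomputed during backward instead of being stored, so only one chunk's logits are live at a time.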

YouJiacheng avatar May 04 '24 07:05 YouJiacheng

This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/

i just tested this with llama3 8B by monkey patching huggingface transformers' model class

  • 1x 80GB A100
  • [B, T] = [1, 32768] input size
  • bf16
  • flash attention (torch SDPA)
  • offloading param and gradient to CPU (because i want to scale up with FSDP after profiling)
  • activation checkpointing (in every layer and offload to CPU)
  • fused CE loss (n_loop_iters=8)

i recorded the GPU memory of 3 fwd-bwd steps from the beginning (i didn't care about warmup)

[Screenshot 2024-05-04 at 6 25 33 PM]

and this is for [B, T] = [4, 20480] @imoneoi

[Screenshot 2024-05-04 at 7 07 52 PM]

SeunghyunSEO avatar May 04 '24 09:05 SeunghyunSEO

@SeunghyunSEO Could you please include the memory usage curve of the baseline (original HF transformers' model)?

YouJiacheng avatar May 04 '24 11:05 YouJiacheng

@YouJiacheng of course i can, but the original model's peak memory would be greater than 12GB, because the full-precision logits alone are already 15+ GB (B*T*vocab*float32 = 1*32768*128256*4 bytes ≈ 15.7 GiB)

SeunghyunSEO avatar May 04 '24 12:05 SeunghyunSEO

Reasonable. I am just curious about the peak and fluctuation on the curve.

Reading the benchmark results in the repo, it seems that loading the weight multiple times incurs nearly zero cost.

-- Ah, I forgot that matrix multiplication with large matrices already loads both matrices into SRAM multiple times. As long as the block size is large enough, memory won't be a bottleneck.

Following this (simplified) algorithm, x @ w will load w M / BLOCK_SIZE_M times if x.shape is (M, K). Thus, as long as the chunk size is a multiple of BLOCK_SIZE_M, there is no extra memory read. In practice, the hierarchy is not only SRAM-DRAM; it can be Shared_Memory-L2_Cache-DRAM, but the story is similar.

https://github.com/mgmalek/efficient_cross_entropy can even be implemented in pure PyTorch without a custom kernel, though an in-place softmax is required to achieve "Optimization 1". But if the loss needs to be computed with good numerical stability, a custom kernel is unavoidable to reduce memory overhead (note: with in-place softmax we can get a numerically stable grad, but we cannot compute a numerically stable loss without log_softmax).
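A pure-PyTorch sketch of that in-place idea (names and toy sizes are illustrative): overwrite one buffer step by step, first with softmax probabilities and then with dlogits, so no second [N, V] tensor is ever allocated:

```python
import torch
import torch.nn.functional as F

N, V = 4, 6
z = torch.randn(N, V)           # stands in for a chunk of logits
c = torch.randint(V, (N,))      # labels

buf = z.clone()
buf.sub_(buf.amax(dim=1, keepdim=True))   # stabilize in place
buf.exp_()
buf.div_(buf.sum(dim=1, keepdim=True))    # buf now holds softmax(z)
buf[torch.arange(N), c] -= 1.0            # softmax - one_hot
buf.div_(N)                               # dlogits of the mean-reduced CE

ref = (torch.softmax(z, dim=1) - F.one_hot(c, V)) / N
print(torch.allclose(buf, ref))  # True
```

The gradient comes out numerically stable this way, but (as noted above) the loss value itself still needs a log_softmax, which this trick does not provide.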

YouJiacheng avatar May 04 '24 12:05 YouJiacheng

@YouJiacheng yeah yeah, will be uploaded!

SeunghyunSEO avatar May 04 '24 12:05 SeunghyunSEO

> This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/

> i just tested this with llama3 8B by monkey patching huggingface transformers' model class and this is for [B, T] = [4, 20480]

@SeunghyunSEO Did you have to make changes to the kernel to work with batch sizes > 1?

winglian avatar May 04 '24 14:05 winglian

> This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/

> i just tested this with llama3 8B by monkey patching huggingface transformers' model class and this is for [B, T] = [4, 20480]

> @SeunghyunSEO Did you have to make changes to the kernel to work with batch sizes > 1?

@winglian you can simply flatten those two axes/dims.
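A minimal illustration of the flattening (toy sizes, names are mine): the fused loss operates on N = B*T rows, so merge the batch and sequence dims first.

```python
import torch

B, T, D, V = 2, 8, 4, 10
hidden = torch.randn(B, T, D)
labels = torch.randint(V, (B, T))

# Merge batch and sequence dims so the kernel sees N = B*T rows.
hidden_2d = hidden.reshape(B * T, D)  # [N, D]
labels_1d = labels.reshape(B * T)     # [N]
print(hidden_2d.shape, labels_1d.shape)
```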

YouJiacheng avatar May 04 '24 14:05 YouJiacheng

> This is also relevant: https://github.com/mgmalek/efficient_cross_entropy/

> i just tested this with llama3 8B by monkey patching huggingface transformers' model class and this is for [B, T] = [4, 20480]

> @SeunghyunSEO Did you have to make changes to the kernel to work with batch sizes > 1?

@winglian no i didn't. like @YouJiacheng said, you can flatten the last hidden states. but it's not optional, it's essential; see this line

SeunghyunSEO avatar May 04 '24 14:05 SeunghyunSEO

> @SeunghyunSEO Could you please include the memory usage curve of the baseline (original HF transformers' model)?

@YouJiacheng this is the result for the case where all settings are the same but the CE loss is not fused. i changed the original code a little bit, like this:

        logits = self.lm_head(hidden_states)
        logits = logits.float()
        logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
        labels = labels[:, 1:].contiguous().view(-1).to(logits.device)
        loss = CrossEntropyLoss()(logits, labels)
        release_memory()
  • 1x 80GB A100
  • [B, T] = [1, 32768] input size
  • bf16
  • flash attention (torch SDPA)
  • offloading param and gradient to CPU (because i want to scale up with FSDP after profiling)
  • activation checkpointing (in every layer and offload to CPU)

i think the fused kernel upcasts logits to float32 too, so it's a fair comparison, right?

[Screenshot 2024-05-04 at 11 46 13 PM]

added) since i recorded only the first three steps (to see the change in GPU memory clearly), there may be noise in the timing measurement.

SeunghyunSEO avatar May 04 '24 14:05 SeunghyunSEO

After reflection, I realize that chunking tokens is better than the tiled softmax method for this use case. The tiled softmax method would require re-computing p in the backward pass, but chunking tokens with a clever custom kernel doesn't.

YouJiacheng avatar May 04 '24 14:05 YouJiacheng

> After reflection, I realize that chunking tokens is better than tiled softmax method for this use-case. The tiled softmax method would require a re-computation of p in the backward pass, but chunking tokens with a clever custom kernel doesn't.

oh, i guess the performance of tiled softmax (like flash attention and your idea) and this kernel would be similar, but it seems like this kernel is slightly better, right?

btw i have one question about this kernel. following this line, if n_loop_iters=8, it doesn't allow sequence lengths that aren't divisible by 8 (e.g. 1510). if so, should i pad the hidden states with -inf and the labels with -100?
(there is nothing wrong with that in my opinion, but...)

(i'm afraid i won't be able to contribute much to this discussion because i'm new to low-level optimization like triton, lol)

SeunghyunSEO avatar May 04 '24 15:05 SeunghyunSEO

@SeunghyunSEO This kernel doesn't support ignore_index, currently. You can also tweak this line to the actual number of tokens in each chunk, and tweak this line to allow non-uniform chunking.

The tiled softmax would require re-computation, increasing FLOPs by ~33%.

This kernel can even be better than https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py since it loads the logits only ONCE, while other implementations load the logits TWICE (forward + backward).

YouJiacheng avatar May 04 '24 15:05 YouJiacheng

> This kernel doesn't support ignore_index, currently. You can tweak [this line]

this line isn't sufficient to support ignore_index? https://github.com/mgmalek/efficient_cross_entropy/blob/049d44460051a82f58f7ff49a2ad0653ecf026d8/modules.py#L56

winglian avatar May 04 '24 15:05 winglian

@winglian Sorry, I overlooked this line!

YouJiacheng avatar May 04 '24 15:05 YouJiacheng

> @SeunghyunSEO ~~This kernel doesn't support ignore_index, currently.~~ You can also tweak this line to the actual number of tokens in each chunk, and tweak this line to allow non-uniform chunking.
>
> The tiled softmax would require re-computation, increasing FLOPs by ~33%.
>
> This kernel can even be better than https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py since it loads the logits only ONCE, while other implementations load the logits TWICE (forward + backward).

agree with that. what a kernel XD and like @winglian said, it looks like supporting ignore_index too!

SeunghyunSEO avatar May 04 '24 15:05 SeunghyunSEO

@balancap I'm curious how you handle the back-to-back GEMMs in the backward pass if we don't want to materialize the large logits tensor (for better perf and to save peak memory usage).

In the backward pass, we need to recompute the logits with a GEMM: [BxT, D] x [D, V] = [BxT, V] (B: batch size, T: sequence length, D: hidden dimension, V: vocabulary size).

Then we need two more GEMMs to:

  • compute the weight gradients: [BxT, D].T x [BxT, V]
  • compute the input gradients: [BxT, V] x [D, V].T

So potentially we need two back-to-back GEMMs here to compute the gradients of the inputs and weights.

I'm thinking we can do something similar to flash attention. But the difference is that the tensors here are not as 'skinny' as in flash attention: there, the size of each row (parameter d in the paper) is the per-head embedding dimension, while here it is the hidden dimension (which is H times larger, H being the number of heads). Maybe the back-to-back GEMMs can still be implemented efficiently, but I'd like to ask whether the larger row size caused any trouble when you implemented something similar for Graphcore.
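A plain-PyTorch sketch of these backward GEMMs (toy sizes, checked against autograd; this shows the math only, not a fused kernel):

```python
import torch
import torch.nn.functional as F

N, D, V = 8, 16, 32  # N = B*T tokens, hidden dim, vocab
x = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
w = torch.randn(D, V, dtype=torch.float64, requires_grad=True)
c = torch.randint(V, (N,))

F.cross_entropy(x @ w, c).backward()  # reference grads via autograd

with torch.no_grad():
    p = torch.softmax(x @ w, dim=1)   # recompute logits -> probs [N, V]
    p[torch.arange(N), c] -= 1.0      # dlogits = softmax - one_hot
    dlogits = p / N                   # mean reduction
    dw = x.T @ dlogits                # [D, N] @ [N, V] -> [D, V]
    dx = dlogits @ w.T                # [N, V] @ [V, D] -> [N, D]

print(torch.allclose(dw, w.grad), torch.allclose(dx, x.grad))  # True True
```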

shunting314 avatar May 07 '24 00:05 shunting314

I tested mgmalek's kernel in a forked repo. https://github.com/kjslag/efficient_cross_entropy

Some tests fail when I include tests that use ignore_index.

kjslag avatar May 28 '24 14:05 kjslag

This happened to be very relevant for a project of ours, so I spent (too much) time looking into it, recently.

Just for reference, I think what this thread needs most is a few more benchmarks, so here are some baselines in TFLOPs, memory, and accuracy, sweeping over N (number of tokens), V (vocab size), and H (hidden dim), with default values N=16384, H=2048, V=131072:

[Plots: fwd-bwd Linear+Loss performance (TFLOPs) over H, over N, and over V. Defaults: N=B*S=16384, H=2048, V=131072]

[Plots: fwd-bwd Linear+Loss peak memory over H, over N, and over V. Same defaults.]

[Plots: Linear+Loss accuracy over H, over N, and over V. Same defaults.]

triton-z-chunks-in-sram is https://gist.github.com/JonasGeiping/6b724907ceb35555a6168dda9b9c4136, a variation of Malek's code (linked above) that is more accurate (and sometimes faster, but also with more autotuning behind it). I think this variant could be even faster if the prologue were fused into each of the two backward matmuls, but that's beyond my Triton abilities to implement efficiently.

The torch checkpoint baseline is

import math
from torch.utils.checkpoint import checkpoint

def torch_checkpoint(x, y, A, default_chunk_size=512):
    # _inner_function computes the per-chunk loss (defined elsewhere)
    loss = 0.0
    _, H = A.shape
    N = x.view(-1, H).shape[0]
    chunk_size = min(default_chunk_size, N)
    if N % chunk_size != 0:
        chunk_size = math.gcd(N, default_chunk_size)
    x_blocks = x.view(-1, H).split(chunk_size)
    y_blocks = y.view(-1).split(chunk_size)

    for x_block, y_block in zip(x_blocks, y_blocks):
        loss += checkpoint(_inner_function, x_block, y_block, A, num_blocks=len(y_blocks),
                           use_reentrant=False)
    return loss

(which needs a rewrite to work reliably under torch.compile)


Overall, having the logits z=Ax stored in SRAM is a bit of an unsatisfying solution to me: they need to be in float32 for precision until softmax(z) = (z - lse).exp() is computed, so they introduce a lot of memory transfer only to be immediately discarded, but I couldn't make other versions work.

The problem, as also discussed above, is that H is just too large. The forward pass is nice and fast, but in the backward pass, rematerializing z=Ax requires parallelization over (N,V) and accumulation over H, while the backward gradients require parallelization over (N,H) and (V,H), respectively. The other constraint is that the matmuls need minimum tile shapes, so the gradient blocks should be at least, e.g. for dx, 16 x H. If this fits, a single kernel can recompute z and immediately compute dx and dA. It doesn't fit for typical H like 4096, leading to a ton of memory load and write traffic (see failed attempts here or here).

JonasGeiping avatar May 30 '24 18:05 JonasGeiping

https://github.com/JonasGeiping/linear_cross_entropy_loss

152334H avatar Jul 20 '24 03:07 152334H

I have a question: do we need to compute the forward pass result at all? I think we can directly fuse everything into a backward kernel during training, so there is no need to save the result?

NonvolatileMemory avatar Aug 15 '24 06:08 NonvolatileMemory

Yeah, both Malek's version and mine (both linked above) directly compute gradients during the forward-pass loop over N and discard everything else. During the backward pass, nothing happens except a scalar multiplication.

JonasGeiping avatar Aug 15 '24 09:08 JonasGeiping

Great!

> Yeah, both Malek's version and mine (both linked above), directly compute gradients during the forward pass loop over N and discard anything else. During the backward pass, nothing happens except for a scalar multiplication.

NonvolatileMemory avatar Aug 15 '24 10:08 NonvolatileMemory