fix compatibility with jax transformations

Open GallagherCommaJack opened this issue 3 years ago • 28 comments

currently it's impossible to use flash_attention inside a function that uses gradient checkpointing (jax.checkpoint)

minimal example to reproduce:

import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

b = 3
lq = 16
lkv = 17
h = 5
d = 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)

fails with error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in <cell line: 1>()
----> 1 get_ipython().run_line_magic('timeit', 'bench_flash_bwd(q, k, v, mask).block_until_ready()')

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2305, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
   2303     kwargs['local_ns'] = self.get_local_scope(stack_depth)
   2304 with self.builtin_trap:
-> 2305     result = fn(*args, **kwargs)
   2306 return result

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:1162, in ExecutionMagics.timeit(self, line, cell, local_ns)
   1160 for index in range(0, 10):
   1161     number = 10 ** index
-> 1162     time_number = timer.timeit(number)
   1163     if time_number >= 0.2:
   1164         break

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:156, in Timer.timeit(self, number)
    154 gc.disable()
    155 try:
--> 156     timing = self.inner(it, self.timer)
    157 finally:
    158     if gcold:

File <magic-timeit>:1, in inner(_it, _timer)

    [... skipping hidden 14 frame]

/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in bench_flash_bwd(q, k, v, mask)
      1 @jax.jit
      2 def bench_flash_bwd(q, k, v, mask):
----> 3     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]), policy=jax.checkpoint_policies.everything_saveable))(q)

    [... skipping hidden 25 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/util.py:48, in safe_map(f, *args)
     46 n = len(args[0])
     47 for arg in args[1:]:
---> 48   assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
     49 return list(map(f, *args))

AssertionError: length mismatch: [3, 1]

GallagherCommaJack • Sep 23 '22 17:09

can confirm that this error also appears under jax.lax.scan

example here:

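# b, lq, lkv, h, d and keys as in the first example; l is the number of scan steps
l = 4  # assumed value; not given in the original
q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v = jax.random.normal(keys[2], (l, b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (l, b, lkv))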


def scan_fn(carry, qkv):
    out = flash_attention(*qkv)[0]
    carry += out
    return carry, out


@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(
        lambda q, k, v, mask: jnp.sum(
            jax.lax.scan(
                scan_fn,
                jnp.zeros_like(q[0]),
                (q, k, v, mask),
            )[0],
        )
    )(q, k, v, mask)


bench_flash_bwd(q, k, v, mask)

GallagherCommaJack • Sep 29 '22 20:09

Thanks for raising this! This looks most likely to be a JAX core bug.

Could you provide a self-contained runnable repro, in particular including the import or definition for flash_attention? (Sorry, I'm not the developer of this repo, so I'm not familiar with that function.)

mattjj • Sep 29 '22 20:09

from flash_attention_jax import flash_attention

GallagherCommaJack • Sep 29 '22 21:09

I ran into this too but never managed to upstream the fix. The trick is basically to do this:

https://github.com/stanford-crfm/levanter/commit/a2828cea3e43700d0f6ead660228f2ec8e6f8c97#diff-658abe908dd5cd256efe9370e7ec2ae9fa2dcdca586a5f886940331e7b56dd09R129-R132

dlwh • Sep 29 '22 21:09

@dlwh it looks like you also ran an autoformatter, so there are a ton of other changes here. Can you say a bit more about how you fixed it?

GallagherCommaJack • Sep 29 '22 21:09

Yeah, sorry, the linked line is the key one. Basically, rename the existing "causal_flash_attention" to "_causal_flash_attention", and add a new custom_vjp-decorated causal_flash_attention that returns just the first result of _causal_flash_attention. Then make flash_attention_forward call _causal_flash_attention instead, and you're done.

@custom_vjp
def causal_flash_attention(q, k, v):
+    return _causal_flash_attention(q, k, v)[0]
+
+
+def _causal_flash_attention(q, k, v):

dlwh • Sep 29 '22 21:09

won't that make flash_attention always do causal masking? I'm using this in a context where that's not appropriate

GallagherCommaJack • Sep 29 '22 21:09

This roughly repeats what @dlwh just said, but I just figured it out and came back to explain: this use of custom_vjp is buggy because the output of flash_attention_forward needs to be a pair whose first element has the same type (pytree structure, shapes and dtypes) as the output of flash_attention. Yet flash_attention returns three arrays, while the first element of flash_attention_forward's return value contains only one array.

There's a JAX bug in that this was a terrible error message to raise, but the fundamental bug is in that use of custom_vjp.
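One way to catch this kind of mismatch without involving any transformation is to compare the two output structures with abstract evaluation (jax.eval_shape). A sketch, assuming hypothetical toy functions shaped like the ones in this issue: a primal that returns `(out, (row_sum, row_max))` and a buggy fwd rule whose first element is only `out`. The helper name `fwd_matches_primal` is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_structure


def fwd_matches_primal(primal_fn, fwd_fn, *args):
    """True if fwd_fn(*args)[0] has the same pytree structure, shapes and
    dtypes as primal_fn(*args). Uses abstract evaluation only (no FLOPs)."""
    primal_out = jax.eval_shape(primal_fn, *args)
    fwd_primal_out, _residuals = jax.eval_shape(fwd_fn, *args)
    if tree_structure(primal_out) != tree_structure(fwd_primal_out):
        return False
    return all(a.shape == b.shape and a.dtype == b.dtype
               for a, b in zip(tree_leaves(primal_out), tree_leaves(fwd_primal_out)))


def attention_like(x):
    # Hypothetical primal: three arrays, like flash_attention's output here.
    row_max = jnp.max(x, axis=-1)
    e = jnp.exp(x - row_max[..., None])
    row_sum = jnp.sum(e, axis=-1)
    return e / row_sum[..., None], (row_sum, row_max)


def buggy_fwd(x):
    out, (row_sum, row_max) = attention_like(x)
    return out, (x, out, row_sum, row_max)  # first element: one array only


def fixed_fwd(x):
    out, (row_sum, row_max) = attention_like(x)
    return (out, (row_sum, row_max)), (x, out, row_sum, row_max)


x = jnp.ones((3, 16, 5))
print(fwd_matches_primal(attention_like, buggy_fwd, x))  # False
print(fwd_matches_primal(attention_like, fixed_fwd, x))  # True
```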

mattjj • Sep 29 '22 21:09

You'll need to make the analogous change to flash_attention, then. As @mattjj said, it's really just a buggy use of custom_vjp. (Though, despite not running, the code was otherwise correct according to my gradient testing!)

dlwh • Sep 29 '22 21:09

Shall I send a PR fix to this repo (maybe you both could review it), and then separately fix the JAX error message? Or @dlwh do you want to send the fix to this repo?

mattjj • Sep 29 '22 21:09

I can probably get to it tonight or tomorrow, but I'm about to go dark for several hours. Totally up to you!

dlwh • Sep 29 '22 21:09

I'll take the first stab, and cc you!

mattjj • Sep 29 '22 21:09

so the relevant fix would be to replace https://github.com/lucidrains/flash-attention-jax/blob/e5efb902c89858c2fa15fce2781492a6b36ddeb4/flash_attention_jax/flash_attention.py#L97 with

    return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)

?

GallagherCommaJack • Sep 29 '22 21:09

interesting that the original code works with plain grad outside of scan and remat; probably it should also fail under grad alone, without either of those?

GallagherCommaJack • Sep 29 '22 21:09

@GallagherCommaJack Yes, that'd work! It's probably the simplest fix, though we could also look at the call sites of flash_attention to see if some other organization would be more natural.

What's a repro for the behavior you're describing? I tried removing jax.checkpoint from the repro in the OP and I still got an error. That is, this still errors for me:

import jax
import jax.numpy as jnp

from flash_attention_jax import flash_attention


b = 3
lq = 16
lkv = 17
h = 5
d = 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)


bench_flash_bwd(q, k, v, mask)

mattjj • Sep 29 '22 21:09

Ah, I think that was just a shape bug; if I set lq = lkv = 16, then I see what you mean.

I think by adding the better JAX error message I described, we'll catch this much earlier and get an error in both cases. I'll be sure to test both with and without checkpoint/scan.

mattjj • Sep 29 '22 21:09

> Yes, that'd work!

Actually, I think it would not work just because the callers expect only a single output there.

I think the issue here was that the custom_vjp-decorated function (i.e. the "primal function") didn't agree with the custom_vjp rule (i.e. their output types didn't agree in the way they should), but when we only use grad (possibly together with jit), we never actually run the primal function; we only run its forward rule. When grad is applied, we only actually run the primal function under a jax.checkpoint or jax.lax.scan (or jax.lax.cond, etc.). That's just because of a JAX implementation detail (these are "initial-style higher-order primitives") which is usually invisible, except, apparently, when there's a type error in a custom_vjp rule!
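Because plain grad only runs the fwd rule, a custom_vjp rule is worth testing under plain jax.grad, under jax.checkpoint, and inside jax.lax.scan. A minimal sketch of such a test, using a hypothetical toy `square` with a correct fwd/bwd pair (`grads_under_transforms` is an illustrative helper, not a JAX API):

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def square(x):
    return x ** 2


def square_fwd(x):
    # (primal output, residuals); the first element matches square(x).
    return x ** 2, x


def square_bwd(x, g):
    return (2.0 * x * g,)


square.defvjp(square_fwd, square_bwd)


def loss(x):
    return jnp.sum(square(x))


def grads_under_transforms(loss_fn, x):
    """Gradient of loss_fn at x under plain grad, under jax.checkpoint, and
    with loss_fn called inside a one-step jax.lax.scan."""
    def scanned(x):
        def body(carry, xi):
            return carry + loss_fn(xi), None
        total, _ = jax.lax.scan(body, jnp.zeros((), x.dtype), x[None])
        return total

    return {
        "grad": jax.grad(loss_fn)(x),
        "checkpoint": jax.grad(jax.checkpoint(loss_fn))(x),
        "scan": jax.grad(scanned)(x),
    }


x = jnp.arange(4.0)
grads = grads_under_transforms(loss, x)
```

All three entries should agree (here, 2 * x); a rule whose fwd output disagrees with the primal output would be expected to pass the first check and fail one of the others.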

mattjj • Sep 29 '22 21:09

with the fix it's working with lq = lkv under jax.checkpoint! still fails with lq != lkv which I'm trying to debug now

GallagherCommaJack • Sep 29 '22 21:09

https://github.com/midjourney/flash-attention-jax/commit/f690412199178bc60fb4a768f28bffb2f27654cb

GallagherCommaJack • Sep 29 '22 21:09

the error with lq = 16; lkv = 17 is TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).

full backtrace:

TypeError                                 Traceback (most recent call last)
Cell In [5], line 22
     18 @jax.jit
     19 def bench_flash_bwd(q, k, v, mask):
     20     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
---> 22 bench_flash_bwd(q, k, v, mask)

    [... skipping hidden 14 frame]

Cell In [5], line 20, in bench_flash_bwd(q, k, v, mask)
     18 @jax.jit
     19 def bench_flash_bwd(q, k, v, mask):
---> 20     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)

    [... skipping hidden 30 frame]

File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:172, in flash_attention_backward(res, do)
    169     dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
    170     return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
--> 172 (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
    174 dq = rearrange(dq, 'c n b h d -> b h (c n) d')
    175 dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))

    [... skipping hidden 11 frame]

File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:170, in flash_attention_backward.<locals>.chunk_scanner(carries, _)
    167 do_chunk = lax.dynamic_slice(do, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))
    169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
--> 170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

    [... skipping hidden 1 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4658, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4656 args = (other, self) if swap else (self, other)
   4657 if isinstance(other, _accepted_binop_types):
-> 4658   return binary_op(*args)
   4659 if isinstance(other, _rejected_binop_types):
   4660   raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4661                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")

    [... skipping hidden 7 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:84, in _maybe_bool_binop.<locals>.fn(x1, x2)
     82 def fn(x1, x2):
     83   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 84   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 7 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/lax/lax.py:1537, in broadcasting_shape_rule(name, *avals)
   1535       result_shape.append(non_1s[0])
   1536     else:
-> 1537       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1538                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1540 return tuple(result_shape)

TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).

GallagherCommaJack • Sep 29 '22 21:09

It looks like one of chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk has a shape error, in flash_attention_backward. (EDIT: I don't feel comfortable debugging that without learning what this code is actually doing, so hopefully someone who knows the code/algorithm can help!)

mattjj • Sep 29 '22 21:09

debugging a bit, it looks like the issue is that dk has shape (h, b, lkv, d) while dk_chunk has shape (h, b, lq, d)

GallagherCommaJack • Sep 29 '22 22:09

@lucidrains it looks like there's an implicit assumption that lq == lkv somewhere in the backward pass, in _query_chunk_flash_attention_backward
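To check a backward pass for the lq != lkv case, a naive attention with the same input layout can serve as a reference: autodiff of it gives dq with q's shape and dk, dv with k's shape, which is what the dk/dv accumulators in flash_attention_backward have to match. A sketch, assuming the (b, l, h, d) layout from the repro, a 1/sqrt(d) scale, and that True in key_mask means "attend" (all assumptions, not taken from the repo's code; `reference_attention` is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, key_mask):
    """Naive, non-chunked attention used only as a correctness reference.

    Assumed shapes: q (b, lq, h, d); k, v (b, lkv, h, d); key_mask (b, lkv) bool.
    """
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    # Masked keys get the most negative finite value, so softmax gives them ~0.
    scores = jnp.where(key_mask[:, None, None, :], scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


b, lq, lkv, h, d = 3, 16, 17, 5, 19  # deliberately lq != lkv
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

loss = lambda q, k, v: jnp.sum(reference_attention(q, k, v, mask))
dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
```

A fixed flash backward could then be compared against these gradients with jnp.allclose, under a tolerance chosen for float32.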

GallagherCommaJack • Sep 29 '22 22:09

@GallagherCommaJack the fix I proposed in #8 is different from the commit you sent, just FYI.

mattjj • Sep 29 '22 22:09

does that work with lq != lkv?

GallagherCommaJack • Sep 29 '22 22:09

looks like it does not

GallagherCommaJack • Sep 29 '22 22:09

Indeed I think the shape issue is unrelated.

mattjj • Sep 29 '22 23:09

google/jax#12611 should improve the error message we got here! With the same repro (i.e. before the fix #7 was merged here), the error will be:

TypeError: Custom VJP fwd rule flash_attention_forward for function
flash_attention must produce a pair (list or tuple of length two) where the
first element represents the primal output (equal to the output of the
custom_vjp-decorated function flash_attention) and the second element
represents residuals (i.e. values stored from the forward pass for use on the
backward pass), but instead the fwd rule output's first element had
container/pytree structure:
    float32[3,16,5,19]
while the custom_vjp-decorated function flash_attention had output
container/pytree structure:
    (float32[3,16,5,19], (float32[3,16,5], float32[3,16,5])).

mattjj • Oct 01 '22 04:10