fix compatibility with jax transformations
it is currently impossible to use flash_attention inside a function that uses gradient checkpointing (jax.checkpoint)
minimal example to reproduce:
b = 3
lq = 16
lkv = 17
h = 5
d = 19
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))
@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
fails with error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in <cell line: 1>()
----> 1 get_ipython().run_line_magic('timeit', 'bench_flash_bwd(q, k, v, mask).block_until_ready()')
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2305, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
2303 kwargs['local_ns'] = self.get_local_scope(stack_depth)
2304 with self.builtin_trap:
-> 2305 result = fn(*args, **kwargs)
2306 return result
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:1162, in ExecutionMagics.timeit(self, line, cell, local_ns)
1160 for index in range(0, 10):
1161 number = 10 ** index
-> 1162 time_number = timer.timeit(number)
1163 if time_number >= 0.2:
1164 break
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:156, in Timer.timeit(self, number)
154 gc.disable()
155 try:
--> 156 timing = self.inner(it, self.timer)
157 finally:
158 if gcold:
File <magic-timeit>:1, in inner(_it, _timer)
[... skipping hidden 14 frame]
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in bench_flash_bwd(q, k, v, mask)
      1 @jax.jit
      2 def bench_flash_bwd(q, k, v, mask):
----> 3     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]), policy=jax.checkpoint_policies.everything_saveable))(q)
[... skipping hidden 25 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/util.py:48, in safe_map(f, *args)
46 n = len(args[0])
47 for arg in args[1:]:
---> 48 assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
49 return list(map(f, *args))
AssertionError: length mismatch: [3, 1]
can confirm that this error also appears under jax.lax.scan
example here:
q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v = jax.random.normal(keys[2], (l, b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (l, b, lkv))
def scan_fn(carry, qkv):
    out = flash_attention(*qkv)[0]
    carry += out
    return carry, out

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(
        lambda q, k, v, mask: jnp.sum(
            jax.lax.scan(
                scan_fn,
                jnp.zeros_like(q[0]),
                (q, k, v, mask),
            )[0],
        )
    )(q, k, v, mask)

bench_flash_bwd(q, k, v, mask)
Thanks for raising this! It looks like a JAX core bug most likely.
Could you provide a self-contained runnable repro, in particular including the import or definition for flash_attention? (Sorry, I'm not the developer of this repo, so I'm not familiar with that function.)
from flash_attention_jax import flash_attention
I ran into this and failed to upstream the fix. The trick is basically to do this:
https://github.com/stanford-crfm/levanter/commit/a2828cea3e43700d0f6ead660228f2ec8e6f8c97#diff-658abe908dd5cd256efe9370e7ec2ae9fa2dcdca586a5f886940331e7b56dd09R129-R132
@dlwh looks like you also ran an autoformatter so there's a ton of other changes here - can you say a bit more about how you fixed it?
Yeah sorry, the line linked is the key one. Basically just rename the method called "causal_flash_attention" to "_causal_flash_attention" and make causal_flash_attention return just the first result. Then make flash_attention_forward call _causal_flash_attention instead, and you're done.
 @custom_vjp
 def causal_flash_attention(q, k, v):
+    return _causal_flash_attention(q, k, v)[0]
+
+
+def _causal_flash_attention(q, k, v):
won't that make flash_attention always do causal masking? I'm using this in a context where that's not appropriate
This is roughly repeating what @dlwh just said, but I just figured it out and came back to explain: this use of custom_vjp is buggy. The output of flash_attention_forward needs to be a pair whose first element has the same type as the output of flash_attention. Yet the output of flash_attention contains three arrays, while the first element of the return value of flash_attention_forward contains only one.
There's a JAX bug in that this was a terrible error message to raise, but the fundamental bug is in that use of custom_vjp.
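To make that contract concrete, here is a minimal sketch on a toy function rather than the real flash_attention code. The names softmax_with_stats, softmax_fwd and softmax_bwd are hypothetical and only mimic the shape of the problem: the primal function returns an output plus a pair of statistics, so the first element returned by the forward rule has to carry exactly the same nested structure.

```python
import jax
import jax.numpy as jnp


def _softmax_with_stats(x):
    # Hypothetical stand-in for the attention kernel: returns the output
    # together with the (row_sum, row_max) statistics.
    row_max = jnp.max(x, axis=-1, keepdims=True)
    exp = jnp.exp(x - row_max)
    row_sum = jnp.sum(exp, axis=-1, keepdims=True)
    return exp / row_sum, (row_sum, row_max)


@jax.custom_vjp
def softmax_with_stats(x):
    return _softmax_with_stats(x)


def softmax_fwd(x):
    out, (row_sum, row_max) = _softmax_with_stats(x)
    # First element: same pytree structure as the primal output.
    # Second element: residuals saved for the backward pass.
    return (out, (row_sum, row_max)), (out,)


def softmax_bwd(res, cotangents):
    (out,) = res
    # The cotangents mirror the primal output structure. This sketch treats
    # the statistics as non-differentiable and ignores their cotangents.
    d_out, _ = cotangents
    dx = out * (d_out - jnp.sum(d_out * out, axis=-1, keepdims=True))
    return (dx,)


softmax_with_stats.defvjp(softmax_fwd, softmax_bwd)

x = jnp.arange(12.0).reshape(3, 4) / 10.0
loss = lambda x: jnp.sum(softmax_with_stats(x)[0] ** 2)
grad_ckpt = jax.grad(jax.checkpoint(loss))(x)
grad_ref = jax.grad(lambda x: jnp.sum(jax.nn.softmax(x, axis=-1) ** 2))(x)
```

With the structures in agreement, the gradient under jax.checkpoint should match the gradient of the plain jax.nn.softmax version.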
you'll need to make the analogous change to flash_attention then. as @mattjj said it's really just a buggy use of custom_vjp. (Though despite it not running the code was otherwise correct according to my gradient testing!)
Shall I send a PR fix to this repo (maybe you both could review it), and then separately fix the JAX error message? Or @dlwh do you want to send the fix to this repo?
I can probably get to it tonight or tomorrow, but I'm about to go dark for several hours. Totally up to you!
I'll take the first stab, and cc you!
so the relevant fix would be to replace https://github.com/lucidrains/flash-attention-jax/blob/e5efb902c89858c2fa15fce2781492a6b36ddeb4/flash_attention_jax/flash_attention.py#L97 with
return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)
?
interesting that this works with grad outside of scan and remat - probably it should fail under grad alone without either of those?
@GallagherCommaJack Yes, that'd work! It's probably the simplest fix, though we could also look at the call sites of flash_attention to see if some other organization would be more natural.
What's a repro for the behavior you're describing? I tried removing jax.checkpoint from the repro in the OP and I still got an error. That is, this still errors for me:
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention
b = 3
lq = 16
lkv = 17
h = 5
d = 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))
@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)
bench_flash_bwd(q, k, v, mask)
Ah, I think it was just a shape bug; if I set lq = lkv = 16 then I see what you mean.
I think by adding the better JAX error message I described, we'll catch this much earlier and get an error in both cases. I'll be sure to test both with and without checkpoint/scan.
> Yes, that'd work!
Actually, I think it would not work just because the callers expect only a single output there.
I think the issue here was that the custom_vjp-decorated function (i.e. the "primal function") didn't agree with the custom_vjp rule (i.e. their output types didn't agree in the way that they should). But when we only use grad (possibly together with jit), we never actually run the primal function; we only run its forward rule. When grad is applied, we only actually run the primal function when under jax.checkpoint or jax.lax.scan (or jax.lax.cond etc.). That's just because of a JAX implementation detail (these are "initial-style higher-order primitives") which is usually invisible, except apparently when there's a type error in a custom_vjp rule!
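For callers that expect a single output, the rename trick described earlier in the thread can be sketched on a toy function. This is not the flash_attention code: normalize, _normalize, normalize_fwd and normalize_bwd are hypothetical names. The idea is that the public custom_vjp function returns only the first result of a private multi-output helper, and the forward rule returns that same single array as its first element.

```python
import jax
import jax.numpy as jnp


def _normalize(x):
    # Hypothetical multi-output helper: the result plus an extra statistic.
    norm = jnp.sqrt(jnp.sum(x * x, axis=-1, keepdims=True))
    return x / norm, norm


@jax.custom_vjp
def normalize(x):
    # Public function: callers only ever see a single array.
    return _normalize(x)[0]


def normalize_fwd(x):
    out, norm = _normalize(x)
    # First element matches the output of `normalize`; second holds residuals.
    return out, (out, norm)


def normalize_bwd(res, g):
    out, norm = res
    # VJP of x / ||x||: (g - out * <g, out>) / ||x||
    dx = (g - out * jnp.sum(g * out, axis=-1, keepdims=True)) / norm
    return (dx,)


normalize.defvjp(normalize_fwd, normalize_bwd)

x = jnp.arange(1.0, 13.0).reshape(3, 4)
w = jnp.arange(4.0)


def loss(x):
    return jnp.sum(normalize(x) * w)


def reference_loss(x):
    return jnp.sum(x / jnp.linalg.norm(x, axis=-1, keepdims=True) * w)


grad_plain = jax.grad(loss)(x)
grad_ckpt = jax.grad(jax.checkpoint(loss))(x)
grad_ref = jax.grad(reference_loss)(x)
```

Because the primal function and the forward rule agree on the output structure, the same gradient should come out with and without jax.checkpoint.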
with the fix it's working with lq = lkv under jax.checkpoint!
still fails with lq != lkv which I'm trying to debug now
https://github.com/midjourney/flash-attention-jax/commit/f690412199178bc60fb4a768f28bffb2f27654cb
the error with lq = 16; lkv = 17 is TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).
full backtrace:
TypeError Traceback (most recent call last)
Cell In [5], line 22
18 @jax.jit
19 def bench_flash_bwd(q, k, v, mask):
20 return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
---> 22 bench_flash_bwd(q, k, v, mask)
[... skipping hidden 14 frame]
Cell In [5], line 20, in bench_flash_bwd(q, k, v, mask)
18 @jax.jit
19 def bench_flash_bwd(q, k, v, mask):
---> 20 return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
[... skipping hidden 30 frame]
File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:172, in flash_attention_backward(res, do)
169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
--> 172 (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
174 dq = rearrange(dq, 'c n b h d -> b h (c n) d')
175 dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))
[... skipping hidden 11 frame]
File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:170, in flash_attention_backward.<locals>.chunk_scanner(carries, _)
167 do_chunk = lax.dynamic_slice(do, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))
169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
--> 170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
[... skipping hidden 1 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4658, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
4656 args = (other, self) if swap else (self, other)
4657 if isinstance(other, _accepted_binop_types):
-> 4658 return binary_op(*args)
4659 if isinstance(other, _rejected_binop_types):
4660 raise TypeError(f"unsupported operand type(s) for {opchar}: "
4661 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
[... skipping hidden 7 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:84, in _maybe_bool_binop.<locals>.fn(x1, x2)
82 def fn(x1, x2):
83 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 84 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
[... skipping hidden 7 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/lax/lax.py:1537, in broadcasting_shape_rule(name, *avals)
1535 result_shape.append(non_1s[0])
1536 else:
-> 1537 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1538 f'{", ".join(map(str, map(tuple, shapes)))}.')
1540 return tuple(result_shape)
TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).
It looks like one of chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk has a shape error, in flash_attention_backward. (EDIT: I don't feel comfortable debugging that without learning what this code is actually doing, so hopefully someone who knows the code/algorithm can help!)
debugging a bit, it looks like the issue is that dk has shape h, b, lkv, d and dk_chunk has shape h, b, lq, d
@lucidrains looks like there's an implicit assumption somewhere in here that lq == lkv in the backwards pass, in _query_chunk_flash_attention_backward
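One way to pin down the expected gradient shapes when lq != lkv is to differentiate a plain, unchunked reference attention. The sketch below is not the library's implementation. It assumes the (b, l, h, d) layout used in the repro, a 1/sqrt(d) scale, and that True in the mask means "attend to this key"; reference_attention is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, key_mask):
    # Assumed layout: q is (b, lq, h, d); k and v are (b, lkv, h, d);
    # key_mask is (b, lkv) with True meaning "attend to this key".
    scale = q.shape[-1] ** -0.5
    sim = jnp.einsum("bihd,bjhd->bhij", q, k) * scale
    sim = jnp.where(key_mask[:, None, None, :], sim, jnp.finfo(sim.dtype).min)
    attn = jax.nn.softmax(sim, axis=-1)
    return jnp.einsum("bhij,bjhd->bihd", attn, v)


b, lq, lkv, h, d = 3, 16, 17, 5, 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

# Gradients with respect to q, k and v take the shapes of q, k and v,
# so dk and dv are indexed by lkv while dq is indexed by lq.
dq, dk, dv = jax.grad(
    lambda q, k, v: jnp.sum(reference_attention(q, k, v, mask)),
    argnums=(0, 1, 2),
)(q, k, v)
print(dq.shape, dk.shape, dv.shape)
```

Here dk and dv carry the lkv axis, so a dk_chunk with an lq-sized axis points to a contraction over the wrong dimension somewhere in the backward pass.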
@GallagherCommaJack the fix I proposed in #8 is different from the commit you sent, just FYI.
does that work with lq != lkv?
looks like it does not
Indeed I think the shape issue is unrelated.
google/jax#12611 should improve the error message we got here! With the same repro (i.e. before the fix #7 was merged here), the error will be:
TypeError: Custom VJP fwd rule flash_attention_forward for function
flash_attention must produce a pair (list or tuple of length two) where the
first element represents the primal output (equal to the output of the
custom_vjp-decorated function flash_attention) and the second element
represents residuals (i.e. values stored from the forward pass for use on the
backward pass), but instead the fwd rule output's first element had
container/pytree structure:
float32[3,16,5,19]
while the custom_vjp-decorated function flash_attention had output
container/pytree structure:
(float32[3,16,5,19], (float32[3,16,5], float32[3,16,5])).