`jax.nn.dot_product_attention` does not respect `key_value_seq_lengths`
Description
Perhaps I am using this function incorrectly, but padded key positions leak into the output when I pass key_value_seq_lengths. Neither the xla nor the cudnn implementation in jax nightly appears to respect this argument. Here is a minimal reproducer:
#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn
# Problem sizes: batch, sequence length, number of heads, model width.
B = 8
L = 128
H = 4
D = 64
rng = random.key(42)
# One random tensor reused as queries, keys and values; each head is D // H wide.
x = random.normal(rng, shape=(B, L, H, D // H), dtype=jnp.bfloat16)
# Per-batch-element count of valid (unpadded) key positions.
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], dtype=jnp.int32)
def vanilla_attention(qs, ks, vs, valid_lens):
    """Reference multi-head attention with an optional key-padding mask.

    Args:
        qs: queries indexed as ``[batch, q_len, heads, head_dim]``.
        ks: keys indexed as ``[batch, kv_len, heads, head_dim]``.
        vs: values indexed as ``[batch, kv_len, heads, head_dim]``.
        valid_lens: optional integer array of shape ``[batch]``. Key positions
            ``>= valid_lens[b]`` are excluded from batch element ``b``'s
            softmax. ``None`` disables masking.

    Returns:
        The attention output with heads merged:
        ``[batch, q_len, heads * head_dim]``.
    """
    # Take sizes from the inputs rather than module globals so the reference
    # works for any shape, including q_len != kv_len.
    batch, q_len, heads, head_dim = qs.shape
    kv_len = ks.shape[1]
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(head_dim)
    if valid_lens is not None:
        # The mask runs along the *key* axis, so its length is kv_len.
        mask = jnp.arange(kv_len) < valid_lens[:, None]  # [B, K]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        # Note: a valid length of 0 masks every key (all -inf), so the
        # softmax row for that batch element becomes NaN.
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    out = jnp.einsum("BHQK,BKHD->BQHD", attn, vs)
    return out.reshape(batch, q_len, heads * head_dim)
def xla_attention(qs, ks, vs, valid_lens):
ctx = nn.dot_product_attention(
qs, ks, vs, key_value_seq_lengths=valid_lens, implementation="xla"
)
return ctx.reshape(B, L, D)
def cudnn_attention(qs, ks, vs, valid_lens):
    """Run ``jax.nn.dot_product_attention`` with the cuDNN backend.

    Only ``key_value_seq_lengths`` is supplied (no ``query_seq_lengths``).
    This is the call pattern the report shows being ignored on the affected
    jax build, so it is kept unchanged here on purpose. Requires a
    cuDNN-capable GPU.

    Args:
        qs, ks, vs: query/key/value arrays laid out as
            ``[batch, seq_len, heads, head_dim]``.
        valid_lens: optional per-batch count of valid key positions, or
            ``None`` for no masking.

    Returns:
        The attention output with heads merged:
        ``[batch, q_len, heads * head_dim]``.
    """
    ctx = nn.dot_product_attention(
        qs, ks, vs, key_value_seq_lengths=valid_lens, implementation="cudnn"
    )
    # Merge the head axes using the output's own shape instead of module
    # globals, so inputs of any size are handled.
    batch, q_len = ctx.shape[:2]
    return ctx.reshape(batch, q_len, -1)
# --- With a key padding mask ------------------------------------------------
# The library wrappers pass only key_value_seq_lengths. The deliberately loose
# tolerance (1.0) shows the library outputs differ grossly from the masked
# reference, while the two library backends still agree with each other.
# The trailing comments are the results reported on jax 0.4.32.dev20240830.
van_attn, xla_attn, cud_attn = (
    attend(x, x, x, valid_lens)
    for attend in (vanilla_attention, xla_attention, cudnn_attention)
)
print(jnp.allclose(van_attn, xla_attn, atol=1.0, rtol=1.0))  # False
print(jnp.allclose(van_attn, cud_attn, atol=1.0, rtol=1.0))  # False
print(jnp.allclose(cud_attn, xla_attn, atol=0.01, rtol=0.01))  # True

# --- Without a mask (sanity check) -------------------------------------------
van_attn, xla_attn, cud_attn = (
    attend(x, x, x, None)
    for attend in (vanilla_attention, xla_attention, cudnn_attention)
)
print(jnp.allclose(van_attn, xla_attn, atol=0.01, rtol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, atol=0.01, rtol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, atol=0.01, rtol=0.01))  # True
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.32.dev20240830
jaxlib: 0.4.31
numpy: 1.26.4
python: 3.12.2 (main, Mar 2 2024, 09:51:01) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ghost', release='6.6.47_1', version='#1 SMP PREEMPT_DYNAMIC Mon Aug 19 16:42:31 UTC 2024', machine='x86_64')
$ nvidia-smi
Fri Aug 30 18:11:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02 Driver Version: 550.107.02 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off |
| 32% 39C P2 78W / 480W | 393MiB / 24564MiB | 3% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 28449 C ...ions/3.12.2/envs/jax/bin/python3.12 386MiB |
+-----------------------------------------------------------------------------------------+
@kaixih PTAL.
I just created a PR to fix this issue. Currently the sequence-length mask is applied only when both query_seq_lengths and key_value_seq_lengths are provided. This PR relaxes that requirement so key_value_seq_lengths can be passed on its own. Can you check whether it works for you?
On the user side, you can also work around this by explicitly passing query_seq_lengths as an array filled with the maximum sequence length.
The following works when I pass the maximum length as query_seq_lengths:
#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn
# Problem sizes: batch, sequence length, number of heads, model width.
B = 8
L = 128
H = 4
D = 64
rng = random.key(42)
# One random tensor reused as queries, keys and values; each head is D // H wide.
x = random.normal(rng, shape=(B, L, H, D // H), dtype=jnp.bfloat16)
# Per-batch-element count of valid (unpadded) key positions.
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], dtype=jnp.int32)
def vanilla_attention(qs, ks, vs, valid_lens):
    """Reference multi-head attention with an optional key-padding mask.

    Args:
        qs: queries indexed as ``[batch, q_len, heads, head_dim]``.
        ks: keys indexed as ``[batch, kv_len, heads, head_dim]``.
        vs: values indexed as ``[batch, kv_len, heads, head_dim]``.
        valid_lens: optional integer array of shape ``[batch]``. Key positions
            ``>= valid_lens[b]`` are excluded from batch element ``b``'s
            softmax. ``None`` disables masking.

    Returns:
        The attention output with heads merged:
        ``[batch, q_len, heads * head_dim]``.
    """
    # Take sizes from the inputs rather than module globals so the reference
    # works for any shape, including q_len != kv_len.
    batch, q_len, heads, head_dim = qs.shape
    kv_len = ks.shape[1]
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(head_dim)
    if valid_lens is not None:
        # The mask runs along the *key* axis, so its length is kv_len.
        mask = jnp.arange(kv_len) < valid_lens[:, None]  # [B, K]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        # Note: a valid length of 0 masks every key (all -inf), so the
        # softmax row for that batch element becomes NaN.
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    out = jnp.einsum("BHQK,BKHD->BQHD", attn, vs)
    return out.reshape(batch, q_len, heads * head_dim)
def xla_attention(qs, ks, vs, valid_lens):
if valid_lens is None:
valid_lens = jnp.repeat(L, B)
ctx = nn.dot_product_attention(
qs,
ks,
vs,
query_seq_lengths=jnp.repeat(L, B),
key_value_seq_lengths=valid_lens,
implementation="xla",
)
return ctx.reshape(B, L, D)
def cudnn_attention(qs, ks, vs, valid_lens):
    """Run ``jax.nn.dot_product_attention`` (cuDNN backend) with a key mask.

    Workaround: on the affected jax build, key_value_seq_lengths was only
    honoured when query_seq_lengths was also supplied. This wrapper therefore
    always passes full-length query lengths explicitly. Requires a
    cuDNN-capable GPU.

    Args:
        qs: queries laid out as ``[batch, q_len, heads, head_dim]``.
        ks, vs: keys/values laid out as ``[batch, kv_len, heads, head_dim]``.
        valid_lens: optional per-batch count of valid key positions. ``None``
            means every key position is valid.

    Returns:
        The attention output with heads merged:
        ``[batch, q_len, heads * head_dim]``.
    """
    # Derive lengths from the inputs rather than module globals.
    batch, q_len = qs.shape[:2]
    kv_len = ks.shape[1]
    query_lens = jnp.full((batch,), q_len, dtype=jnp.int32)
    if valid_lens is None:
        valid_lens = jnp.full((batch,), kv_len, dtype=jnp.int32)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=query_lens,
        key_value_seq_lengths=valid_lens,
        implementation="cudnn",
    )
    return ctx.reshape(batch, q_len, -1)
# --- With a key padding mask ------------------------------------------------
# The library wrappers pass query_seq_lengths explicitly, so all three results
# agree within the tight 0.01 tolerance.
van_attn, xla_attn, cud_attn = (
    attend(x, x, x, valid_lens)
    for attend in (vanilla_attention, xla_attention, cudnn_attention)
)
print(jnp.allclose(van_attn, xla_attn, atol=0.01, rtol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, atol=0.01, rtol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, atol=0.01, rtol=0.01))  # True

# --- Without a mask (sanity check) -------------------------------------------
van_attn, xla_attn, cud_attn = (
    attend(x, x, x, None)
    for attend in (vanilla_attention, xla_attention, cudnn_attention)
)
print(jnp.allclose(van_attn, xla_attn, atol=0.01, rtol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, atol=0.01, rtol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, atol=0.01, rtol=0.01))  # True
@danjenson Is it a typical use case for you to provide only key_value_seq_lengths?
Constantly. I usually want an answer for every "query", but each query may only use specific data/keys when answering it.