cudnn-frontend
sdpa_fp8 with different seqlen_q and seqlen_k
Hi, I tried running an sdpa_fp8 graph where seqlen_q and seqlen_k are different. However, it seems that only seqlen_q is used: performance stays the same when I sweep only seqlen_k. Here is the function I wrote:
```python
import math

import torch
import cudnn

# `convert_to_cudnn_type` and `generate_layout` are helpers defined elsewhere
# in my script.


def cudnn_spda_setup(q, k, v, seqlen_q, seqlen_k, causal=False):
    b, _, nheads, headdim = q.shape
    assert cudnn is not None, 'CUDNN is not available'
    device = q.device

    # Output buffer in (b, seqlen_q, nheads, headdim) layout, viewed as
    # (b, nheads, seqlen_q, headdim) for the graph.
    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=q.dtype, device=device)
    o_gpu_transposed = torch.as_strided(
        o_gpu,
        [b, nheads, seqlen_q, headdim],
        [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
    )
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=device)
    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=device)
    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=device)

    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    # Q uses seqlen_q; K and V use seqlen_k.
    shape_q = (b, nheads, seqlen_q, headdim)
    shape_k = (b, nheads, seqlen_k, headdim)
    shape_v = (b, nheads, seqlen_k, headdim)
    shape_o = (b, nheads, seqlen_q, headdim)

    qkv_num_elems = math.prod(shape_q) + math.prod(shape_k) + math.prod(shape_v)
    (stride_q, stride_k, stride_v, stride_o, offset_q, offset_k, offset_v) = generate_layout(
        shape_q,
        shape_k)

    # Random packed QKV buffer; Q/K/V are strided views into it.
    qkv = torch.randn(qkv_num_elems, dtype=torch.float16, device="cuda")
    qkv_gpu = qkv.to(q.dtype)
    q_gpu = torch.as_strided(qkv_gpu, shape_q, stride_q, storage_offset=offset_q)
    k_gpu = torch.as_strided(qkv_gpu, shape_k, stride_k, storage_offset=offset_k)
    v_gpu = torch.as_strided(qkv_gpu, shape_v, stride_v, storage_offset=offset_v)

    q = graph.tensor(
        name="Q",
        dim=list(q_gpu.shape),
        stride=list(q_gpu.stride()),
        data_type=convert_to_cudnn_type(qkv_gpu.dtype),
    )
    k = graph.tensor(
        name="K",
        dim=list(k_gpu.shape),
        stride=list(k_gpu.stride()),
        data_type=convert_to_cudnn_type(qkv_gpu.dtype),
    )
    v = graph.tensor(
        name="V",
        dim=list(v_gpu.shape),
        stride=list(v_gpu.stride()),
        data_type=convert_to_cudnn_type(qkv_gpu.dtype),
    )

    def get_default_scale_tensor():
        return graph.tensor(
            dim=[1, 1, 1, 1],
            stride=[1, 1, 1, 1],
            data_type=cudnn.data_type.FLOAT,
        )

    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
    descale_q = get_default_scale_tensor()
    descale_k = get_default_scale_tensor()
    descale_v = get_default_scale_tensor()
    descale_s = get_default_scale_tensor()
    scale_s = get_default_scale_tensor()
    scale_o = get_default_scale_tensor()

    o, _, amax_s, amax_o = graph.sdpa_fp8(
        q=q,
        k=k,
        v=v,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
        descale_s=descale_s,
        scale_s=scale_s,
        scale_o=scale_o,
        is_inference=True,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        name="sdpa",
    )

    o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
    amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
    amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
    # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        descale_q: default_scale_gpu,
        descale_k: default_scale_gpu,
        descale_v: default_scale_gpu,
        descale_s: default_scale_gpu,
        scale_s: default_scale_gpu,
        scale_o: default_scale_gpu,
        o: o_gpu_transposed,
        amax_s: amax_s_gpu,
        amax_o: amax_o_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu, amax_o_gpu

    return run
```
What am I doing wrong? Thanks.
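For reference, this is roughly how I benchmark the returned `run` callable while sweeping seqlen_k. This is a minimal sketch: the batch, head, and head-dim values are arbitrary, and the cast to `torch.float8_e4m3fn` is my assumption for the fp8 input dtype.

```python
import torch

b, nheads, headdim = 2, 16, 128
seqlen_q = 1024
dtype = torch.float8_e4m3fn  # assumed fp8 input dtype for sdpa_fp8

for seqlen_k in [512, 1024, 2048, 4096]:
    # Inputs in (b, seqlen, nheads, headdim) layout, matching cudnn_spda_setup.
    q = torch.randn(b, seqlen_q, nheads, headdim, dtype=torch.float16, device="cuda").to(dtype)
    k = torch.randn(b, seqlen_k, nheads, headdim, dtype=torch.float16, device="cuda").to(dtype)
    v = torch.randn(b, seqlen_k, nheads, headdim, dtype=torch.float16, device="cuda").to(dtype)
    run = cudnn_spda_setup(q, k, v, seqlen_q, seqlen_k)

    # Warm up, then time the graph execution with CUDA events.
    for _ in range(5):
        run()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(20):
        run()
    end.record()
    torch.cuda.synchronize()
    print(f"seqlen_k={seqlen_k}: {start.elapsed_time(end) / 20:.3f} ms")
```

With this loop, the reported time per iteration barely changes across the seqlen_k values, which is what makes me suspect only seqlen_q is being picked up.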