cudnn-frontend

sdpa_fp8 with different seqlen_q and seqlen_k

Open MustafaFayez opened this issue 5 months ago • 0 comments

Hi, I tried running an sdpa_fp8 graph where seqlen_q and seqlen_k differ. However, it seems to use only seqlen_q: performance stays the same when I sweep seqlen_k alone. Here is the function I wrote:

def cudnn_spda_setup(q, k, v, seqlen_q, seqlen_k, causal=False):
    """Build a cuDNN FP8 SDPA inference graph and return a closure that runs it.

    Only ``q`` is read, for its shape ``(b, _, nheads, headdim)``, its dtype and
    its device. ``k`` and ``v`` are accepted for call-site compatibility but are
    not used. Q, K and V are filled with fresh random data.

    Q and O use ``seqlen_q``; K and V use ``seqlen_k``. Each of Q, K, V and O
    gets its own buffer stored as (b, seqlen, nheads, headdim) and is handed to
    cuDNN as a (b, nheads, seqlen, headdim) view.

    Args:
        q: Tensor whose shape, dtype and device parameterize the graph.
        k: Unused.
        v: Unused.
        seqlen_q: Sequence length of Q and O.
        seqlen_k: Sequence length of K and V.
        causal: Forwarded to ``sdpa_fp8`` as ``use_causal_mask``.

    Returns:
        ``run(*args, **kwargs)``, which executes the graph (its arguments are
        ignored) and returns ``(o_gpu, amax_o_gpu)``. Here ``o_gpu`` is the
        (b, seqlen_q, nheads, headdim) output buffer and ``amax_o_gpu`` is the
        1x1x1x1 float32 buffer bound to the graph's ``amax_o`` output.
    """
    b, _, nheads, headdim = q.shape
    assert cudnn is not None, 'CUDNN is not available'
    device = q.device
    io_type = convert_to_cudnn_type(q.dtype)

    def random_bhsd(seqlen):
        # Store as (b, seqlen, h, d) and return a (b, h, seqlen, d) view.
        # Sample in fp16 and then cast to the I/O dtype.
        bshd = torch.randn(b, seqlen, nheads, headdim, dtype=torch.float16, device=device)
        return bshd.to(q.dtype).transpose(1, 2)

    # Give each tensor its own buffer so K/V dims and strides follow seqlen_k
    # alone, instead of depending on a packed-QKV layout helper.
    q_gpu = random_bhsd(seqlen_q)
    k_gpu = random_bhsd(seqlen_k)
    v_gpu = random_bhsd(seqlen_k)

    # O lives in (b, seqlen_q, h, d) storage and is exposed as (b, h, seqlen_q, d).
    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=q.dtype, device=device)
    o_gpu_transposed = o_gpu.transpose(1, 2)
    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=device)
    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=device)
    # A single all-ones scalar is bound to every descale/scale input.
    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device=device)

    graph = cudnn.pygraph(
        io_data_type=io_type,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    def io_tensor(name, gpu):
        return graph.tensor(
            name=name,
            dim=list(gpu.shape),
            stride=list(gpu.stride()),
            data_type=io_type,
        )

    def scale_tensor():
        return graph.tensor(
            dim=[1, 1, 1, 1],
            stride=[1, 1, 1, 1],
            data_type=cudnn.data_type.FLOAT,
        )

    # Graph-side tensors carry a _t suffix so they don't shadow the q/k/v parameters.
    q_t = io_tensor("Q", q_gpu)
    k_t = io_tensor("K", k_gpu)
    v_t = io_tensor("V", v_gpu)
    descale_q = scale_tensor()
    descale_k = scale_tensor()
    descale_v = scale_tensor()
    descale_s = scale_tensor()
    scale_s = scale_tensor()
    scale_o = scale_tensor()

    # NOTE(review): when seqlen_q != seqlen_k, confirm which diagonal alignment
    # (top-left vs bottom-right) cuDNN's causal mask uses.
    o_t, _, amax_s, amax_o = graph.sdpa_fp8(
        q=q_t,
        k=k_t,
        v=v_t,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
        descale_s=descale_s,
        scale_s=scale_s,
        scale_o=scale_o,
        is_inference=True,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        name="sdpa",
    )

    o_t.set_output(True).set_dim(list(o_gpu_transposed.shape)).set_stride(
        list(o_gpu_transposed.stride())
    )
    # amax_o_gpu is returned to the caller, and both amax buffers are bound in the
    # variant pack. So both amax tensors must be real FP32 graph outputs; a
    # non-output (virtual) tensor is not materialized into the bound buffer.
    for amax_t, amax_gpu in ((amax_s, amax_s_gpu), (amax_o, amax_o_gpu)):
        amax_t.set_output(True).set_dim(list(amax_gpu.shape)).set_stride(
            list(amax_gpu.stride())
        ).set_data_type(cudnn.data_type.FLOAT)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q_t: q_gpu,
        k_t: k_gpu,
        v_t: v_gpu,
        descale_q: default_scale_gpu,
        descale_k: default_scale_gpu,
        descale_v: default_scale_gpu,
        descale_s: default_scale_gpu,
        scale_s: default_scale_gpu,
        scale_o: default_scale_gpu,
        o_t: o_gpu_transposed,
        amax_s: amax_s_gpu,
        amax_o: amax_o_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device=device, dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu, amax_o_gpu

    return run

What am I doing wrong? Thanks.

MustafaFayez avatar Sep 21 '24 06:09 MustafaFayez