
[Question] Making dO contiguous affects output?

Open · vedantroy opened this issue 6 months ago · 2 comments

I've noticed that when using PyTorch's custom autograd functions, the stride of dO can sometimes be (0, 0, 0, 0). Here's a very simple example: https://discuss.pytorch.org/t/getting-unusual-strides-when-using-pytorchs-autograd/208093.
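
For illustration, here's a minimal sketch that reproduces this (the Probe function is hypothetical, only there to print the incoming gradient's strides):

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dO):
        # sum()'s backward expands a scalar grad across the input shape,
        # so dO arrives here as a broadcast view with stride (0, 0, 0, 0)
        print(dO.stride())
        return dO * 2

x = torch.randn(2, 3, 4, 5, requires_grad=True)
Probe.apply(x).sum().backward()  # prints (0, 0, 0, 0)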

In my custom wrapper for cuDNN, I work around this by making dO contiguous when its stride is all zeros. Code (ctrl-f for "CHECK FOR WEIRD STRIDE"):

import cudnn
import torch
import math

def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")

def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # match cuDNN's docs
    H, D = num_heads, head_dim
    del num_heads, head_dim

    # cache persists compiled cuDNN graphs, graph tensor handles, and
    # reusable buffers across forward/backward calls
    cache = {}

    def assert_cudnn_shape(tensor, expected_shape):
        assert tuple(tensor.get_dim()) == expected_shape, f"Expected shape {expected_shape} but got {tensor.get_dim()}"

    def init_or_check_tensor_attrs(tensor_name, tensor):
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            if key not in cache:
                cache[key] = getattr(tensor, attr)
                if callable(cache[key]):
                    cache[key] = cache[key]()
            else:
                # compare the tensor's current attribute against the cached one
                v = getattr(tensor, attr)
                if callable(v):
                    v = v()
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"


    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)

            # cuDNN plans are compiled for a specific shape, stride, and dtype,
            # so we need to verify those attributes on every call
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)

            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4) # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)

            assert not k_view.is_contiguous() and not v_view.is_contiguous(), "kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, N + L, D), f"Got shape {k_view.shape} instead of {(B, H, N + L, D)}"
            assert v_view.shape == (B, H, (N + L), D)

            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)

            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']

            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)

            if 'compiled_graph_fwd' not in cache:
                print("Compiling CuDNN forward graph ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach())
                }
                cu_tens = cache['name_to_cu_tensor']

                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu']
                )

                # output tensors must be declared with explicit dims/strides;
                # these are baked into the compiled plan
                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())

                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()

                cache['compiled_graph_fwd'] = g_fwd

                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    g_fwd.get_workspace_size(),
                    device=q.device, dtype=torch.uint8
                )
            
            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv)
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])
            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o

        @staticmethod
        def backward(ctx, dO):
            q, k_view, v_view, o, stats, seqlens_kv = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']
            cu_tens = cache['name_to_cu_tensor']

            init_or_check_tensor_attrs('dO', dO)

            # CHECK FOR WEIRD STRIDE
            # if every stride of dO is 0 (a broadcast scalar), materialize it
            # as a dense tensor
            if all(s == 0 for s in dO.stride()):
                dO = dO.contiguous()
            assert dO.shape == (B, H, N, D)
            # dO = dO.contiguous()

            if 'dQ' not in cache:
                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)

            dQ, dK, dV = cache['dQ'], cache['dK'], cache['dV']

            if 'compiled_graph_bwd' not in cache:
                print(f"Compiling CuDNN backward graph ...")
                g_bwd = cudnn.pygraph(
                     io_data_type=dtype,
                     intermediate_data_type=cudnn.data_type.FLOAT,
                     compute_data_type=cudnn.data_type.FLOAT,
                )

                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(dO.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())

                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd']
                )

                dQ_bwd_cu.set_output(True).set_dim(dQ.size()).set_stride(dQ.stride())
                dK_bwd_cu.set_output(True).set_dim(dK.size()).set_stride(dK.stride())
                dV_bwd_cu.set_output(True).set_dim(dV.size()).set_stride(dV.stride())

                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()

                cache['compiled_graph_bwd'] = g_bwd
                cache['workspace'] = torch.empty(
                    max(cache['compiled_graph_fwd'].get_workspace_size(), cache['compiled_graph_bwd'].get_workspace_size()),
                    device=q.device, dtype=torch.uint8
                )

            variant_pack_backward = {
                cu_tens[name]: tensor for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k_view),
                    ('v_cu_bwd', v_view),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', dO),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens_kv)
                ]
            }

            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])
            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3) # B H N D -> B N H D

            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)

            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)

            dKV = dKV.permute(0, 3, 2, 1, 4) # B H 2 (N+L) D -> B (N+L) 2 H D

            return None, None, None, dQ, dKV, None

    return CuDNNAttention
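
For context, here is a minimal sketch of how I call the returned class (the concrete sizes are hypothetical):

B, N, L, H, D = 2, 128, 16, 8, 64
attn = make_cudnn_autograd(num_heads=H, head_dim=D, dtype=torch.float16)

q = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
kv = torch.randn(B, N + L, 2, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
seqlens_kv = torch.full((B,), N + L, device="cuda", dtype=torch.int32)

o = attn.apply(B, N, L, q, kv, seqlens_kv)  # o has shape (B, H, N, D)
o.sum().backward()  # a reduction like this is exactly where the zero-stride dO shows up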

The problem is that when I do this, I get massive numerical error. Do you have thoughts on why making dO contiguous might cause issues?
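
Note that .contiguous() on a zero-stride (broadcast) tensor changes only the layout, not the values, which is what makes this surprising. A minimal illustration of what the workaround does:

import torch

dO = torch.zeros(1).expand(2, 3, 4, 5)  # broadcast view, stride (0, 0, 0, 0)
dense = dO.contiguous()                 # dense copy of the same values
print(dO.stride())             # (0, 0, 0, 0)
print(dense.stride())          # (60, 20, 5, 1)
print(torch.equal(dO, dense))  # True: values identical, only strides differ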

vedantroy · Aug 14, 2024