cudnn-frontend
[Question] Making dO contiguous affects output?
I've noticed that when using PyTorch's custom autograd functions, the stride of `dO` can sometimes be `(0, 0, 0, 0)`. Here's a very simple example: https://discuss.pytorch.org/t/getting-unusual-strides-when-using-pytorchs-autograd/208093.
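For reference, here's a minimal sketch of how I can reproduce the zero-stride gradient (`PassThrough` and the shapes are made up for illustration; the mechanism is the same as in the linked thread):

```python
import torch

class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # sum() backpropagates an expanded scalar, so autograd can hand
        # this function a gradient whose stride is (0, 0, 0, 0).
        print(grad_out.stride())
        return grad_out

x = torch.randn(2, 8, 128, 64, requires_grad=True)
PassThrough.apply(x).sum().backward()  # prints (0, 0, 0, 0)
```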
In my custom wrapper for cuDNN, I work around this by making `dO` contiguous when its stride is all zeros. Code (ctrl-f for "CHECK FOR WEIRD STRIDE"):
```python
import math

import cudnn
import torch


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")


def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # Match the naming in cuDNN's docs.
    H, D = num_heads, head_dim
    del num_heads, head_dim

    cache = {}

    def assert_cudnn_shape(tensor, expected_shape):
        assert tuple(tensor.get_dim()) == expected_shape, \
            f"Expected shape {expected_shape} but got {tensor.get_dim()}"

    def init_or_check_tensor_attrs(tensor_name, tensor):
        # Cache each attribute the first time we see the tensor and verify
        # it on every subsequent call.
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            v = getattr(tensor, attr)
            if callable(v):
                v = v()
            if key not in cache:
                cache[key] = v
            else:
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"

    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)

            # cuDNN plans are compiled for a specific shape, stride, dtype,
            # so we need to verify those attributes.
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)

            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4)  # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)
            assert not k_view.is_contiguous() and not v_view.is_contiguous(), \
                "kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, N + L, D), \
                f"Got shape {k_view.shape} instead of {(B, H, N + L, D)}"
            assert v_view.shape == (B, H, N + L, D)

            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)
            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']
            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)

            if 'compiled_graph_fwd' not in cache:
                print("Compiling cuDNN forward graph ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach()),
                }
                cu_tens = cache['name_to_cu_tensor']
                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu'],
                )
                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())
                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()
                cache['compiled_graph_fwd'] = g_fwd
                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    g_fwd.get_workspace_size(),
                    device=q.device, dtype=torch.uint8,
                )

            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv),
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])

            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o

        @staticmethod
        def backward(ctx, dO):
            q, k_view, v_view, o, stats, seqlens_kv = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']
            cu_tens = cache['name_to_cu_tensor']

            init_or_check_tensor_attrs('dO', dO)
            # CHECK FOR WEIRD STRIDE
            # If every stride of dO is 0 (an expanded tensor), materialize it.
            if all(s == 0 for s in dO.stride()):
                dO = dO.contiguous()
            assert dO.shape == (B, H, N, D)
            # dO = dO.contiguous()

            if 'dQ' not in cache:
                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)
            dQ, dK, dV = cache['dQ'], cache['dK'], cache['dV']

            if 'compiled_graph_bwd' not in cache:
                print("Compiling cuDNN backward graph ...")
                g_bwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(dO.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())
                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd'],
                )
                dQ_bwd_cu.set_output(True).set_dim(dQ.size()).set_stride(dQ.stride())
                dK_bwd_cu.set_output(True).set_dim(dK.size()).set_stride(dK.stride())
                dV_bwd_cu.set_output(True).set_dim(dV.size()).set_stride(dV.stride())
                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()
                cache['compiled_graph_bwd'] = g_bwd
                # The shared workspace must be big enough for both graphs.
                cache['workspace'] = torch.empty(
                    max(cache['compiled_graph_fwd'].get_workspace_size(),
                        cache['compiled_graph_bwd'].get_workspace_size()),
                    device=q.device, dtype=torch.uint8,
                )

            variant_pack_backward = {
                cu_tens[name]: tensor for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k_view),
                    ('v_cu_bwd', v_view),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', dO),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens_kv),
                ]
            }
            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])

            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3)  # B H N D -> B N H D
            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)
            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)
            dKV = dKV.permute(0, 3, 2, 1, 4)  # B H 2 S D -> B S 2 H D
            return None, None, None, dQ, dKV, None

    return CuDNNAttention
```
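For completeness, this is roughly how I drive the wrapper (the sizes below are placeholders for illustration, not my real config):

```python
# Placeholder sizes, illustration only.
B, N, L = 2, 128, 64
H, D = 8, 64

CuDNNAttention = make_cudnn_autograd(num_heads=H, head_dim=D, dtype=torch.float16)

q = torch.randn(B, N, H, D, dtype=torch.float16, device='cuda', requires_grad=True)
kv = torch.randn(B, N + L, 2, H, D, dtype=torch.float16, device='cuda', requires_grad=True)
seqlens_kv = torch.full((B,), N + L, dtype=torch.int32, device='cuda')

o = CuDNNAttention.apply(B, N, L, q, kv, seqlens_kv)
o.sum().backward()  # this is where dO shows up with stride (0, 0, 0, 0)
```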
The problem is, when I do this, I get massive numerical error. Do you have thoughts on why making `dO` contiguous might cause issues?
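To be clear, as far as I can tell the copy itself shouldn't change any values, only the memory layout (shapes again illustrative):

```python
import torch

# A zero-stride "expanded" tensor like the dO autograd hands me:
dO = torch.ones((), dtype=torch.float16).expand(2, 8, 128, 64)
print(dO.stride())            # (0, 0, 0, 0)

dO_c = dO.contiguous()        # what my CHECK FOR WEIRD STRIDE branch does
print(dO_c.stride())          # (65536, 8192, 64, 1)
print(torch.equal(dO, dO_c))  # True -- element values are identical
```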