eqy
For #138340. We might consider more sophisticated logic here, but the corresponding logic in other backends doesn't seem to do anything fancy for the non-BSHD/BHSD cases: https://github.com/pytorch/pytorch/blob/ea8ea2f33fc65b33dc562f4b0430f8c79eb81d8d/aten/src/ATen/native/transformers/cuda/attention.cu#L1145 cc @csarofeen @ptrblck...
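To make the "nothing fancy" case concrete, here is a minimal sketch (a hypothetical helper, not the actual check in `attention.cu`) of accepting only the two stride patterns that correspond to a contiguous BHSD tensor and a BSHD tensor viewed via `.transpose(1, 2)`; everything else would fall through to a contiguous copy:

```python
import torch

def is_bshd_or_bhsd(t: torch.Tensor) -> bool:
    """Hypothetical layout check for a 4-D attention input of shape
    [B, H, S, D]: accept plain contiguous (BHSD) or a transpose(1, 2)
    view of a contiguous [B, S, H, D] tensor (BSHD)."""
    if t.dim() != 4 or t.stride(-1) != 1:
        return False
    _, h, s, d = t.shape
    bhsd = t.stride() == (h * s * d, s * d, d, 1)  # contiguous [B, H, S, D]
    bshd = t.stride() == (h * s * d, d, h * d, 1)  # transpose(1, 2) view
    return bhsd or bshd

q = torch.randn(2, 8, 128, 64)                    # BHSD, contiguous
q_t = torch.randn(2, 128, 8, 64).transpose(1, 2)  # BSHD view
print(is_bshd_or_bhsd(q), is_bshd_or_bhsd(q_t))   # True True
```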
The `float16` tolerance was previously set to `1e-5`, which seemed far too strict for half precision.
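For context (a generic sketch, not the test from this PR): `float16` has eps ≈ 9.8e-4, so an absolute tolerance of `1e-5` is below the format's resolution near 1.0 and would flag ordinary rounding differences:

```python
import torch

q = torch.randn(1, 2, 16, 8, dtype=torch.float16)
ref = q.float().softmax(dim=-1).half()  # reference computed in fp32, then cast
out = q.softmax(dim=-1)                 # computed directly in fp16

# atol=1e-5 would reject plain fp16 rounding noise; a bound on the order
# of fp16 eps (~9.8e-4) is the realistic scale for half precision.
torch.testing.assert_close(out, ref, atol=2e-3, rtol=1e-3)
```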
Disabled by default for now behind `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1`. Just wanted to get this out before starting a series of SDPA cleanup PRs---the biggest thing is we don't need the boilerplate around...
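Opting in would look roughly like the following (a sketch, assuming a CUDA build with cuDNN SDPA support; the env-var name is from this PR, the rest is the standard nested-tensor SDPA pattern):

```python
import os
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Opt-in gate from this PR; it is read by the C++ backend selection, so it
# must be set before the SDPA call (set here only for illustration).
os.environ["TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED"] = "1"

# Jagged (B, S*, H, D) nested tensor, transposed to (B, H, S*, D) for SDPA.
qs = [torch.randn(s, 8, 64, device="cuda", dtype=torch.float16) for s in (16, 12)]
q = torch.nested.nested_tensor(qs, layout=torch.jagged).transpose(1, 2)

# Request the cuDNN backend explicitly; without the env var set, backend
# selection skips cuDNN for nested inputs and falls back to other backends.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, q, q)
```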