lightning-thunder
Makes cudnn a default executor
Before submitting
- [x] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?
What does this PR do?
cuDNN is now a default executor. The only operation it targets is sdpa.
The main change is a stricter checker function: both forward and backward graph support are now verified before the executor claims the sdpa operation. (The checker was previously made lenient in #57.)
Fixes #418.
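For context, the stricter check can be thought of as "build both cuDNN graphs before claiming the op". The sketch below is illustrative only; `make_strict_checker`, `build_fwd`, and `build_bwd` are hypothetical names, not the actual helpers in thunder/executors/cudnnex.py:

```python
from collections.abc import Callable

def make_strict_checker(build_fwd: Callable[..., object],
                        build_bwd: Callable[..., object]) -> Callable[..., bool]:
    """Claim an op only if both the forward and backward graphs build."""
    def checker(*args, **kwargs) -> bool:
        try:
            build_fwd(*args, **kwargs)   # would cuDNN accept the forward graph?
            build_bwd(*args, **kwargs)   # ...and the backward graph?
        except Exception:
            return False                 # decline; another executor handles the op
        return True
    return checker
```

Building the graphs up front is what makes the checker stricter than the lenient version from #57, at the cost of doing graph construction inside the check itself.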
PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
Before:
benchmark: 5 tests (two shown)

| Name (time in ms) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| test_llama2_7b_sdpa_grad[thunder+cudnn] | 16.1154 (1.0) | 17.0821 (1.0) | 16.4272 (1.0) | 0.2377 (1.0) | 16.4072 (1.0) | 0.3396 (1.56) | 9;1 | 60.8745 (1.0) | 40 | 1 |
| test_llama2_7b_sdpa_grad[thunder] | 27.0254 (1.68) | 28.7664 (1.68) | 27.5484 (1.68) | 0.4476 (1.88) | 27.3732 (1.67) | 0.3508 (1.61) | 12;4 | 36.2998 (0.60) | 40 | 1 |
After:
benchmark: 5 tests (two shown)

| Name (time in ms) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| test_llama2_7b_sdpa_grad[thunder+cudnn] | 16.1055 (1.0) | 16.8787 (1.0) | 16.4041 (1.0) | 0.2354 (1.0) | 16.4000 (1.0) | 0.4397 (1.05) | 18;0 | 60.9606 (1.0) | 40 | 1 |
| test_llama2_7b_sdpa_grad[thunder] | 16.1180 (1.00) | 17.0175 (1.01) | 16.4810 (1.00) | 0.2475 (1.05) | 16.4507 (1.00) | 0.4206 (1.0) | 17;0 | 60.6759 (1.00) | 40 | 1 |
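(These are pytest-benchmark tables; `(1.0)` marks the per-column baseline. Before this change, plain `thunder` was ~1.68x slower than `thunder+cudnn` on this test; after it, the two configurations match to within noise, since cuDNN now claims sdpa by default.)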
Thank you!
@parthmannan how do you feel about this from a perf perspective? One concern might be the `_make_cudnn_sdpa_*_graph` expense, but maybe that's 1) cheaper than I worry about, or 2) not that relevant because `_cudnn_sdpa_checker` doesn't get called all that often anyway?
@tfogal I don't think we need to worry about this from a performance perspective, as you pointed out in 2. I don't expect this to be called very often; hopefully just on the first iteration for static shapes, if I am thinking about this correctly.
What happens with dynamic shapes? We eventually plan to support those. Can cuDNN create broader graphs that work for many shapes, or will we need to call this every time?
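One way to bound the cost under dynamic shapes would be to memoize the check on the tensor signature, sketched below under the assumption that support depends only on shapes and dtypes. The names are hypothetical, not thunder's API:

```python
from functools import lru_cache

import torch

# Hypothetical memoization of the expensive support check: key the result on
# the (shape, dtype) signature so each unique configuration builds the cuDNN
# graphs at most once, even when shapes vary across iterations.
@lru_cache(maxsize=None)
def _sdpa_supported_for(signature: tuple) -> bool:
    # Placeholder for the real graph-construction check.
    return True

def cudnn_sdpa_checker(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> bool:
    signature = tuple((tuple(t.shape), t.dtype) for t in (query, key, value))
    return _sdpa_supported_for(signature)
```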
@t-vi this is ready for your final review and merge. :)
@vedaanta I think the test failures are real / caused by this patch:
```
jfn = thunder.jit(module)
> result = jfn(*args, **kwargs)
thunder/tests/test_jit_general.py:608:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:49: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:626: in fn_
    result = cache_entry.computation_fn(*inps)
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:28: in decorate_autocast
    return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:28: in decorate_autocast
    return func(*args, **kwargs)
thunder.computation_729:68: in computation
    (y, _, _, _) = cudnn_sdpa_fwd(q, t74, t78, None, 0, True, scale=None)
thunder/executors/cudnnex.py:357: in _cudnn_sdpa_fwd_impl
    with torch.cuda.device(query.device):
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:361: in __init__
    self.idx = _get_device_index(device, optional=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
device = device(type='cpu'), optional = True, allow_cpu = False
    def _get_device_index(
        device: Any, optional: bool = False, allow_cpu: bool = False
    ) -> int:
        r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
        If :attr:`device` is a torch.device object, returns the device index if it
        is a CUDA device. Note that for a CUDA device without a specified index,
...
```
I think we should hold off pending investigation; please take a look and let's double-check we're not causing regressions here.
Yeah, as @tfogal points out, I think something is up in the tests.
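For what it's worth, the traceback shows the checker claiming sdpa for CPU tensors (`device = device(type='cpu')` reaching `torch.cuda.device`). A guard along these lines, purely illustrative and not the actual fix, would make the checker decline non-CUDA inputs so thunder falls back to the reference sdpa implementation on CPU:

```python
import torch

# Illustrative guard (hypothetical helper, not thunder's API): decline any
# call whose tensors are not on a CUDA device.
def _all_on_cuda(*tensors: torch.Tensor | None) -> bool:
    return all(t.device.type == "cuda" for t in tensors if t is not None)
```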