
Makes cudnn a default executor

Open · vedaanta opened this pull request 1 year ago · 4 comments

Before submitting
  • [x] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Did you make sure to update the docs?
  • [x] Did you write any new necessary tests?

What does this PR do?

cuDNN is now a default executor. The only operation targeted is sdpa.

The main change is a stricter checker function: both forward- and backward-graph support are verified before the executor claims an sdpa operation. (The checker was previously made lenient in #57.)

Fixes #418.
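To make the "claim only when both graphs build" idea concrete, here is a minimal hedged sketch. It is not the actual lightning-thunder code (the real checker, _cudnn_sdpa_checker, lives in thunder/executors/cudnnex.py); the parameterized graph builders stand in for the _make_cudnn_sdpa_*_graph helpers mentioned later in this thread.

```python
# Minimal sketch of the stricter-checker idea -- NOT the actual
# lightning-thunder implementation. The executor claims sdpa only when
# BOTH the forward and backward cuDNN graphs can be built, so it never
# accelerates the forward pass and then fails during backward.

def _try_build(graph_builder, *args, **kwargs) -> bool:
    """Return True iff the (hypothetical) graph builder succeeds."""
    try:
        graph_builder(*args, **kwargs)
    except Exception:
        return False
    return True

def strict_sdpa_checker(make_fwd_graph, make_bwd_graph, *sdpa_args, **sdpa_kwargs) -> bool:
    # Both builders must succeed before the op is claimed; otherwise the
    # executor declines and another executor handles sdpa instead.
    return (_try_build(make_fwd_graph, *sdpa_args, **sdpa_kwargs)
            and _try_build(make_bwd_graph, *sdpa_args, **sdpa_kwargs))
```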

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

vedaanta avatar May 16 '24 18:05 vedaanta

Before:

---------------------------------------------------------------------------------------------------- benchmark: 5 tests ---------------------------------------------------------------------------------------------------
Name (time in ms)                                               Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_llama2_7b_sdpa_grad[thunder+cudnn]                     16.1154 (1.0)      17.0821 (1.0)      16.4272 (1.0)      0.2377 (1.0)      16.4072 (1.0)      0.3396 (1.56)          9;1  60.8745 (1.0)          40           1
test_llama2_7b_sdpa_grad[thunder]                           27.0254 (1.68)     28.7664 (1.68)     27.5484 (1.68)     0.4476 (1.88)     27.3732 (1.67)     0.3508 (1.61)         12;4  36.2998 (0.60)         40           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After:

---------------------------------------------------------------------------------------------------- benchmark: 5 tests ---------------------------------------------------------------------------------------------------
Name (time in ms)                                               Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_llama2_7b_sdpa_grad[thunder+cudnn]                     16.1055 (1.0)      16.8787 (1.0)      16.4041 (1.0)      0.2354 (1.0)      16.4000 (1.0)      0.4397 (1.05)         18;0  60.9606 (1.0)          40           1
test_llama2_7b_sdpa_grad[thunder]                           16.1180 (1.00)     17.0175 (1.01)     16.4810 (1.00)     0.2475 (1.05)     16.4507 (1.00)     0.4206 (1.0)          17;0  60.6759 (1.00)         40           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
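(Presumably the two rows now match because, with cuDNN registered as a default executor, the plain thunder configuration also dispatches sdpa through cuDNN.)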

vedaanta avatar May 16 '24 21:05 vedaanta

Thank you!

@parthmannan how do you feel about this from a perf perspective? One concern might be the _make_cudnn_sdpa_*_graph expense, but maybe that's 1) cheaper than I worry about, or 2) not that relevant because _cudnn_sdpa_checker doesn't get called all that often anyway?

@tfogal I don't think we need to worry about this from a performance perspective, as you pointed out in 2. — I don't expect this to be called very often. Hopefully just the first iteration for static shapes, if I am thinking about this correctly.

What happens with dynamic shapes? We eventually plan to support those; can cuDNN create broader graphs that work for many shapes, or will we need to call this every time?

parthmannan avatar May 17 '24 18:05 parthmannan
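As an editorial aside on the cost question above: a common way to amortize the graph-build expense is to memoize the support check on the properties cuDNN actually cares about (shape, dtype, causal flag, and so on). With static shapes the expensive probe then runs once; with dynamic shapes every new shape pays it once. A hedged sketch with entirely hypothetical names and a dummy probe, not code from this PR:

```python
# Illustrative only: memoize support checks so repeated calls with the
# same shapes/dtypes are free after the first probe.
from functools import lru_cache

@lru_cache(maxsize=None)
def _sdpa_supported(q_shape: tuple, kv_shape: tuple, dtype: str, is_causal: bool) -> bool:
    # Stand-in for the real "try to build fwd+bwd cuDNN graphs" probe;
    # a real implementation would attempt the graph builds here.
    return len(q_shape) == 4 and dtype in ("float16", "bfloat16")
```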

@t-vi this is ready for your final review and merge. :)

vedaanta avatar May 21 '24 18:05 vedaanta

@vedaanta I think the test failures are real / caused by this patch:

        jfn = thunder.jit(module)
>       result = jfn(*args, **kwargs)

thunder/tests/test_jit_general.py:608: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:49: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:626: in fn_
    result = cache_entry.computation_fn(*inps)
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:28: in decorate_autocast
    return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:28: in decorate_autocast
    return func(*args, **kwargs)
thunder.computation_729:68: in computation
    (y, _, _, _) = cudnn_sdpa_fwd(q, t74, t78, None, 0, True, scale=None)
thunder/executors/cudnnex.py:357: in _cudnn_sdpa_fwd_impl
    with torch.cuda.device(query.device):
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:361: in __init__
    self.idx = _get_device_index(device, optional=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device = device(type='cpu'), optional = True, allow_cpu = False

    def _get_device_index(
        device: Any, optional: bool = False, allow_cpu: bool = False
    ) -> int:
        r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
    
        If :attr:`device` is a torch.device object, returns the device index if it
        is a CUDA device. Note that for a CUDA device without a specified index,

> this is ready for your final review and merge. :)

I think we should hold off pending investigation; please take a look and let's double-check we're not causing regressions here.

tfogal avatar May 21 '24 19:05 tfogal
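A hedged reading of the traceback above: cudnn_sdpa_fwd was claimed for CPU tensors, and torch.cuda.device(query.device) then fails for a non-CUDA device. A guard along these lines in the checker would rule that case out (an illustrative sketch only, not necessarily the fix the PR adopted):

```python
import torch

def _all_on_cuda(*tensors) -> bool:
    # Decline to claim sdpa unless every (non-None) tensor is on a CUDA device.
    return all(t is None or t.device.type == "cuda" for t in tensors)

q = torch.randn(1, 8, 16, 64)       # CPU tensor, as in the failing test
print(_all_on_cuda(q, q, q, None))  # False -> don't claim sdpa for cuDNN
```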

Yeah, as @tfogal points out, I think something is up in the tests.

t-vi avatar May 24 '24 06:05 t-vi