
SDPA Flash kernel

Open finnkauski opened this issue 5 months ago • 6 comments

Hi,

The 2.2 release of PyTorch added a great integration to speed up transformer attention-based architectures:

FlashAttentionV2 backend for scaled dot product attention

An example of how it is intended to be used in the Python API is this context manager.
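For concreteness, here's a minimal sketch of how I understand it gets used from Python (the shapes are made up, and I'm assuming a CUDA device with fp16 inputs, since that's what the flash kernel wants):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the FlashAttention backend only; this errors out if the
# flash kernel cannot handle the inputs instead of silently falling back
# to the math path.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```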

Following down the rabbit hole we land here:

torch._C._set_sdp_use_flash(enabled)

and following this we land here.

There are also the mem_efficient and math backends, which are closely linked and worth exposing.
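As far as I can tell, the context manager is just a thin wrapper over a set of global toggles in torch.backends.cuda, so a sketch of the calls that would want exposing looks like this (the two non-flash setters are my reading of the backends module, not something I've traced all the way down):

```python
import torch

# Global switches behind the context manager; the flash one ends up calling
# torch._C._set_sdp_use_flash(enabled), and the other two have analogous
# internal setters.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

# Read back the current state.
print(torch.backends.cuda.flash_sdp_enabled())          # True
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # False
print(torch.backends.cuda.math_sdp_enabled())           # False
```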

NOTE: not sure if this is already in the bindings and I just can't find it?

finnkauski avatar Feb 10 '24 18:02 finnkauski

I don't think this is available yet, and for these very specific operations I'm a bit dubious we will have good support anytime soon. If you care about using flash-attention, I would recommend switching to candle instead, where this is already well supported and gives you access to flash-attention v2, which is significantly faster than the first version.

LaurentMazare avatar Feb 10 '24 20:02 LaurentMazare

The way I've read the info, it seems to be some form of switch that gets flicked (I don't know how it actually works under the hood, but the calls to enable flash attention feel like they set some global state), so that Tensor::scaled_dot_product_attention dispatches different kernels under the hood. I was hoping it would be something they expose so we could toggle it too. This isn't a pressing issue right now, but I'm thinking ahead to what I might need, hence the question.
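(For anyone following along: a quick way to convince yourself the switch really changes the dispatched kernel is to profile the call and look at the kernel names. Just a sketch, and the exact kernel names will vary between builds:)

```python
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Force the flash backend, then profile a single SDPA call.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        F.scaled_dot_product_attention(q, k, v, is_causal=True)

# If the flash path was taken, a kernel with "flash" in its name shows up here.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```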

If you think we can't expose this here then I think we can close this issue.

Side note: I'm a big fan of candle, but I'm unfortunately working backwards on my project from candle to torch, as candle is missing a few operations, such as CUDA ConvTranspose1d, that I need for my use case, and my naive kernel for it has been embarrassingly slow! Even setting that kernel aside (i.e. subtracting the time it takes to run), I couldn't get performance to the level I could with torch for now. I've got the codebase in two branches now, candle and tch, for reference, so I might be able to contribute some insight into the comparison for my use case in the future.

finnkauski avatar Feb 10 '24 21:02 finnkauski

What are the ops that you're missing on the candle side? If it's just a matter of Conv1d and ConvTranspose1d, we could have a look at hooking cudnn here; this should bring you the best available kernels for these ops, and you would benefit from all the "modern" aspects of candle as an ML framework which tch doesn't have.

LaurentMazare avatar Feb 11 '24 13:02 LaurentMazare

It was indeed those two mainly for this part of the project. I'm sure other bits might hit some walls too.

Essentially, here's the issue on the candle side: it's ConvTranspose1d that I'm missing, and obviously the faster the regular Conv1d is, the better for my case.

I assumed you folks are swamped with TODOs, so I was just going to wait until this gets picked up before exploring candle further.

finnkauski avatar Feb 11 '24 13:02 finnkauski

Ah, hooking conv1d to use cudnn is probably the easiest thing to do but won't help the convtranspose1d bit here. I'll have to dig a bit to see how to do the transposed versions in cudnn (I think that's what pytorch does).
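For context, the transposed conv is essentially the gradient of the forward conv with respect to its input, which is why the same cudnn machinery should be able to cover both directions. A quick sanity check of that relationship on the PyTorch side (made-up shapes, chosen so that no output_padding is needed):

```python
import torch
import torch.nn.functional as F

# Made-up shapes: batch 2, 4 in-channels, length 16; 8 out-channels, kernel 4.
x = torch.randn(2, 4, 16, requires_grad=True)
w = torch.randn(8, 4, 4)  # conv1d weight layout: (out, in, kernel)

y = F.conv1d(x, w, stride=2, padding=1)  # -> (2, 8, 8)
g = torch.randn_like(y)                  # some upstream gradient

# Gradient of conv1d with respect to its input...
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=g)

# ...matches conv_transpose1d with the same weight and hyper-parameters
# (conv_transpose1d reads the weight as (in, out, kernel) from its side,
# which is exactly conv1d's (out, in, kernel) tensor here).
t = F.conv_transpose1d(g, w, stride=2, padding=1)  # -> (2, 4, 16)

print(torch.allclose(grad_x, t, atol=1e-5))  # True
```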

LaurentMazare avatar Feb 11 '24 13:02 LaurentMazare

Just to mention that I merged a naive conv-transpose1d cuda implementation on the candle side. It's certainly on the very slow side, so it may well be a bottleneck for your use case, but hopefully I can make a cudnn version of it too.

LaurentMazare avatar Feb 12 '24 14:02 LaurentMazare