
Use torch sdpa implementation in ASR mha

Open WoodieDudy opened this issue 1 year ago • 16 comments

Hello. I changed the MHA implementation for the ASR modules so that it uses torch.nn.functional.scaled_dot_product_attention.
This speeds up the MHA forward pass by 27% and the backward pass by 17% on an A100.
PyTorch SDPA is continuously being optimized, so NeMo would keep benefiting from the latest performance improvements.
My code uses the memory-efficient backend in SDPA because flash attention doesn't support a custom attention bias. There is ongoing work to contribute custom-bias support in the flash-attention repository. PR.
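
For context, here is a minimal sketch of the core idea (not the exact PR code; shapes and tensor names are illustrative, and it assumes a recent PyTorch with the torch.nn.attention API and a CUDA device): the relative-position term is passed to SDPA as an additive attn_mask, and the kernel selection is restricted to backends that accept such a bias.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, T, D = 2, 8, 128, 64
q = torch.randn(B, H, T, D, device="cuda")
k = torch.randn(B, H, T, D, device="cuda")
v = torch.randn(B, H, T, D, device="cuda")
rel_pos_bias = torch.randn(B, H, T, T, device="cuda")  # additive bias, e.g. the relative-position term

# flash attention rejects a dense additive bias, so restrict SDPA to backends that accept one
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)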

What else do I need to do to get this PR merged?

Usage
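
A minimal usage sketch (the constructor arguments follow the benchmark below; the only new piece is the use_pytorch_sdpa flag):

from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

# same positional arguments as in the benchmark below; use_pytorch_sdpa switches to torch SDPA
mha = RelPositionMultiHeadAttention(8, 512, 0.0, None, None, use_pytorch_sdpa=True).to("cuda")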

Here is my benchmark:

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

torch.manual_seed(123)

device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
n_head = 8

query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()  # mask: True - make zero, False - leave unchanged
mask = None  # the benchmark runs without a mask; remove this line to benchmark the masked case
pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

attention_sdpa = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=True).to(device)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=False).to(device)
for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
    original_param.data.copy_(sdpa_param.data)

# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )

    with torch.no_grad():
        torch.cuda.synchronize()
        results = timer.blocked_autorange(min_run_time=10)
        forward_time = results.mean
        output = attention(query, key, value, mask, pos_emb)
    return forward_time, output


def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )
    torch.cuda.synchronize()
    results = timer.blocked_autorange(min_run_time=10)
    fwd_bwd_time = results.mean
    return fwd_bwd_time


time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)

print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")

# Original implementation time: 0.049075 seconds
# SDPA implementation time: 0.035598 seconds
# SDPA boost 27.46%
# Original implementation backward time: 0.127004 seconds
# SDPA implementation backward time: 0.104986 seconds
# SDPA backward boost 17.34%
# Outputs are the same

PR Type:

  • [x] New Feature
  • [ ] Bugfix
  • [ ] Documentation

Who can review?

cc @titu1994 @SeanNaren

Additional Information

WoodieDudy avatar Jul 02 '24 12:07 WoodieDudy

I also attempted to run the tests in the repository but encountered an issue: NaNs appear when a mask with a fully False row is passed to MHA. With such a mask, filling matrix_bd with -inf via matrix_bd.masked_fill_(mask.logical_not(), float("-inf")) produces a row that is entirely -inf, and after the softmax that whole row becomes NaN. I am unsure how to resolve this, since the softmax and the multiplication by value happen inside torch.nn.functional.scaled_dot_product_attention and I cannot intervene. In your implementation this is handled by manually filling the masked positions with zeros: attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0).
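
A minimal standalone repro of the effect (not NeMo code):

import torch

scores = torch.zeros(1, 4)
mask = torch.zeros(1, 4, dtype=torch.bool)  # a fully False row: every position is masked out
scores = scores.masked_fill(mask.logical_not(), float("-inf"))  # the whole row becomes -inf
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]])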

What can be done about this? Should we write a separate test that doesn't use such masks as input?

WoodieDudy avatar Jul 03 '24 06:07 WoodieDudy

@SeanNaren @titu1994 Having an SDPA option is a good thing also because a Triton-based version of FAv2 with custom attn_bias support (FlexAttention) is being added to PyTorch core: https://github.com/pytorch/pytorch/pull/130250#issuecomment-2216318222, so in the future Conformer attention can benefit from the speed-ups and proper compilation work happening on SDPA in core PyTorch.

vadimkantorov avatar Jul 09 '24 09:07 vadimkantorov

@SeanNaren @titu1994 Haha, and now that FAv3 is out, PyTorch will probably integrate it as well in the near term, for maximum brr on H100 :) So having NeMo's Conformer automatically benefit from this new work would be awesome.

vadimkantorov avatar Jul 12 '24 13:07 vadimkantorov

cc @redoctopus @jbalam-nv @okuchaiev

WoodieDudy avatar Jul 18 '24 10:07 WoodieDudy

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot] avatar Aug 02 '24 01:08 github-actions[bot]

stale bump

vadimkantorov avatar Aug 02 '24 20:08 vadimkantorov

Thanks for the contribution! @titu1994 please take a look at this PR, as it looks like an interesting addition to speed up Conformer models.

Just some notes:

  • Please use -10000 instead of -inf if possible, as -inf may cause NaNs with some data types (see the sketch after this list).

  • Please add it as an option to the config files, somewhere like here, so it can be controlled from configs: https://github.com/NVIDIA/NeMo/blob/07d536babf25f34d2866e7fa76c05eeef3b0086b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml#L147 Name it "use_pytorch_sdpa"?

  • I suggest making it True by default if we can make sure it works in all cases. @titu1994 what do you think?

  • Please evaluate one of the pretrained models from NGC on LibriSpeech test-other to make sure that it produces the exact same output and accuracy.

  • You need to set the dropout to zero manually when the model is not in training mode, as SDPA does not respect the module's training flag and always applies dropout (see the sketch after this list).

  • Have you used matrix_ac in your code/calculations?
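
Regarding the -10000 and dropout points, here is a hedged sketch of what is meant (function and constant names are illustrative, not the exact NeMo code; the mask convention here is True = masked out):

import torch.nn.functional as F

NEG_INF_APPROX = -10000.0  # large negative value instead of float("-inf"), safer for low-precision dtypes

def sdpa_with_bias(q, k, v, bias, mask, dropout_rate, training):
    if mask is not None:
        # True = masked out (illustrative convention); fill with a large negative value, not -inf
        bias = bias.masked_fill(mask, NEG_INF_APPROX)
    # F.scaled_dot_product_attention applies dropout_p unconditionally, so gate it on training mode manually
    dropout_p = dropout_rate if training else 0.0
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=dropout_p)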

VahidooX avatar Aug 07 '24 21:08 VahidooX

@VahidooX Big thanks for review!

  • I replaced -inf with -10000
  • Added use_pytorch_sdpa to config
  • Fixed dropout for torch sdpa
  • I don't calculate matrix_ac manually; it is computed under the hood of torch.nn.functional.scaled_dot_product_attention (see its reference implementation), as sketched below:
    attn_weight = q_with_bias_u @ key.transpose(-2, -1) * scale_factor
    # so matrix_ac would be equivalent to attn_weight
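
To make the equivalence concrete, here is a small standalone sketch (shapes and names are illustrative; rel_shift and masking are omitted):

import math
import torch
import torch.nn.functional as F

B, H, T, D = 2, 8, 16, 64
q_with_bias_u = torch.randn(B, H, T, D)  # query plus the "u" positional bias, as in NeMo's rel-pos MHA
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)
matrix_bd = torch.randn(B, H, T, T)      # stands in for the (rel-shifted) content-position term

# original path: matrix_ac is materialized explicitly
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
attn = torch.softmax((matrix_ac + matrix_bd) / math.sqrt(D), dim=-1)
out_original = torch.matmul(attn, v)

# SDPA path: matrix_ac is computed inside the kernel, only the bias is passed in
out_sdpa = F.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=matrix_bd / math.sqrt(D))

print(torch.allclose(out_original, out_sdpa, atol=1e-5))  # expected True, up to numerical differences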
    

Do you have a script for computing metrics on LibriSpeech? And reference metrics?

WoodieDudy avatar Aug 12 '24 17:08 WoodieDudy

Thanks @titu1994. Okay, I'll try to add tests.

WoodieDudy avatar Aug 12 '24 17:08 WoodieDudy

I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward. I think the random_length data in this test is incorrect, because random_length must have a single dimension of shape (batch,), not two dimensions. Am I right? https://github.com/NVIDIA/NeMo/blob/6774fdcf23a915421135b223404b2213bf2c1b7a/tests/collections/asr/test_conformer_encoder.py#L84

WoodieDudy avatar Aug 14 '24 12:08 WoodieDudy

I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward. I think the random_length data in this test is incorrect, because random_length must have a single dimension of shape (batch,), not two dimensions. Am I right?

https://github.com/NVIDIA/NeMo/blob/6774fdcf23a915421135b223404b2213bf2c1b7a/tests/collections/asr/test_conformer_encoder.py#L84

Yes, that looks incorrect.

VahidooX avatar Aug 15 '24 01:08 VahidooX

As I mentioned above here, I encountered an error caused by the mask for padding tokens being all False, which led to NaNs appearing after the softmax operation.

I resolved this issue by adding dummy columns and rows to the query, key, value, and mask: https://github.com/NVIDIA/NeMo/blob/8d292d6ef218d5dc5f40878cde2052448e08aabb/nemo/collections/asr/parts/submodules/multi_head_attention.py#L160-L162
However, this may not be the most appropriate solution, since it could affect the divisibility of the dimensions by powers of two.

Do you think this approach is acceptable?

Recently, a similar issue was resolved for SDPA in https://github.com/pytorch/pytorch/pull/131863. Once that fix lands in a stable PyTorch release, we can remove this workaround with the additional columns.

WoodieDudy avatar Aug 21 '24 17:08 WoodieDudy

Maybe we should run a training test with mem efficient backend before setting use_pytorch_sdpa=True by default. Does anyone have a ready script for this?

WoodieDudy avatar Aug 21 '24 18:08 WoodieDudy

As I mentioned above here, I encountered an error caused by the mask for padding tokens being all False, which led to NaNs appearing after the softmax operation.

I resolved this issue by adding dummy columns and rows to the query, key, value, and mask

https://github.com/NVIDIA/NeMo/blob/8d292d6ef218d5dc5f40878cde2052448e08aabb/nemo/collections/asr/parts/submodules/multi_head_attention.py#L160-L162

However, this may not be the most appropriate solution since it could affect the divisibility of dimensions by powers of two. Do you think this approach is acceptable?

Recently, a similar issue was resolved for SDPA in pytorch/pytorch#131863. Once that fix lands in a stable PyTorch release, we can remove this workaround with the additional columns.

Hmm, did you try using -10000 instead of -inf and then slicing off the excess, as suggested in the above PR as option 2?

If that doesn't work, I'm wondering if we should merge this without the workaround and guard this code path so it is only used once PyTorch 2.4 is released? We can do a version check and run the correct code for RelPositionMultiHeadAttention.
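
A hedged sketch of such a guard (the constant name is illustrative; the 2.4 threshold follows the suggestion above):

import torch
from packaging import version

# gate the SDPA code path on a PyTorch version that contains the needed fixes
SDPA_MIN_TORCH_VERSION = version.parse("2.4.0")
use_sdpa_path = version.parse(torch.__version__) >= SDPA_MIN_TORCH_VERSION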

We can fall back to the older code path for lower PyTorch versions. @VahidooX what do you think?

titu1994 avatar Aug 22 '24 00:08 titu1994

Oh god, sorry for all the mess around with git :(

WoodieDudy avatar Aug 26 '24 18:08 WoodieDudy

So @titu1994, yes, it was possible to get rid of the NaNs after softmax by simply zeroing out the fully masked rows, which is what I just did:

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)

# this if-block can be deleted once https://github.com/pytorch/pytorch/pull/131863 is in a stable release
if mask is not None:
    # rows where every key position is masked produce NaN after the softmax; zero them out
    all_masked_rows = torch.all(~mask, dim=-1)
    all_masked_rows.unsqueeze_(-1)
    out = out.masked_fill(all_masked_rows, 0.0)

In the end, it seems that before merging with the flag set to True by default, it would be good to run a training test to ensure everything works fine. Does anyone have the code ready for this?

WoodieDudy avatar Aug 26 '24 18:08 WoodieDudy

I'm on it - launched Canary-1B training with this, will keep you posted.

pzelasko avatar Aug 29 '24 15:08 pzelasko

Unfortunately the Canary-1B training on 32 GPUs with bf16 AMP is diverging after ~10 hours with this SDPA implementation. I launched a second run with a different seed to confirm and it replicates. I have a ~1 month old baseline run that works fine, and I launched another from today's main that also runs fine. Any ideas what may have caused this?

(attached image: training loss curves)

pzelasko avatar Aug 30 '24 21:08 pzelasko

@pzelasko I think the first step should be running SDPA with the most basic math algo selected:

from torch.nn.attention import SDPBackend, sdpa_kernel

# needs to be checked whether this breaks torch.compile(fullgraph=True) or not, see https://github.com/pytorch/pytorch/issues/130098
with sdpa_kernel(backends=[SDPBackend.MATH]):
    # the PR currently defaults to SDPBackend.EFFICIENT_ATTENTION, as it's supposedly the only backend besides MATH that supports a custom attn_bias arg
    x = torch.nn.functional.scaled_dot_product_attention(...)

This would validate/benchmark that the basic SDPA implementation works fine (it should be the most similar to the original NeMo implementation). I heard somewhere on the PyTorch forum/issues that the numerics of the different SDPA backends are slightly different, and this can lead to blow-ups during internal accumulation :( This is why I think use_pytorch_sdpa should not be the default just yet, and the SDPA backend selection should be exposed in the config too. In the future, if @WoodieDudy takes on https://github.com/Dao-AILab/flash-attention/pull/617, the FLASH_ATTENTION backend would also become available. And I'm not sure whether the modern CUDNN_ATTENTION backend can handle a custom attention bias; maybe that's also supported there by now.

For more collaborative debugging, it would be great if some subset of multilingual ASRSET-1 were published as filelists. Then relatively large (larger than LibriSpeech) training runs could be done by users for validation (e.g. to try again using RoPE or ALiBi positional encoding + FAv2) and for more reproducible research.

Maybe we need to look at loss spikes, activation ranges, and weight ranges... maybe this can be fixed by simply clipping some QKV projection weights, or by skipping batches with a large loss (after initial training). The latter is not trivial to do with DDP, as the replicas must stay in sync, but it is still possible.

vadimkantorov avatar Aug 31 '24 09:08 vadimkantorov

@pzelasko what PyTorch version did you use? We found that the latest versions were more stable. Did you try with bf16_mixed?

orena1 avatar Aug 31 '24 18:08 orena1

So, let's make use_pytorch_sdpa False by default and merge this PR? Then we can maybe continue the work in a different PR.

WoodieDudy avatar Sep 10 '24 20:09 WoodieDudy

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot] avatar Sep 25 '24 01:09 github-actions[bot]

That sounds fine to me. Can you make it False by default? @titu1994 wdyt?

nithinraok avatar Sep 25 '24 02:09 nithinraok

Sounds ok to me

titu1994 avatar Sep 25 '24 05:09 titu1994

I set the flag to False by default everywhere. I also added a use_pytorch_sdpa_backends argument in which you can set the list of backends for SDPA.
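
For reference, a hedged sketch of how such a backend list could be applied (the helper name and accepted strings are illustrative, not necessarily the exact NeMo API):

import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def run_sdpa_with_backends(backend_names, q, k, v, attn_mask=None, dropout_p=0.0):
    # map config strings such as ["EFFICIENT_ATTENTION", "MATH"] to SDPBackend members
    backends = [getattr(SDPBackend, name) for name in backend_names]
    with sdpa_kernel(backends):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)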

WoodieDudy avatar Sep 26 '24 18:09 WoodieDudy

Let's merge? @titu1994

WoodieDudy avatar Sep 26 '24 18:09 WoodieDudy

It looks ok for inference at least; let's avoid it for training for now.

titu1994 avatar Sep 27 '24 02:09 titu1994

@VahidooX @pzelasko 👀

WoodieDudy avatar Oct 01 '24 05:10 WoodieDudy

Uh, sorry for the lack of responsiveness lately. I used bf16-mixed and PyTorch 2.4.0 with CUDA 12.5. I agree we should merge it disabled by default. Let's take a look again later with the newer backends.

We should also enable it for inference by default if possible.

pzelasko avatar Oct 01 '24 15:10 pzelasko

Just initiated the CI run; it looks like some of the jobs are failing like this: https://github.com/NVIDIA/NeMo/actions/runs/11128061705/job/30922390803?pr=9590

Please address them, and once CI passes this is good to go. Thanks!

nithinraok avatar Oct 01 '24 15:10 nithinraok