Use torch sdpa implementation in ASR mha
Hi. I changed the MHA implementation for the ASR modules so that it uses torch.nn.functional.scaled_dot_product_attention.
This accelerated the MHA forward pass by 27% and the backward pass by 17% on an A100.
PyTorch SDPA is continuously being optimized, so we benefit from the latest performance improvements.
My code uses the memory-efficient backend in SDPA because flash attention doesn't support a custom attention bias. There is ongoing work to contribute custom bias support to the flash-attention repository (see the linked PR).
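For illustration, here is a minimal sketch (not the PR's actual code; shapes and dtypes are assumptions) of pinning SDPA to the memory-efficient backend while passing a custom additive bias, using the torch.nn.attention.sdpa_kernel context manager available in recent PyTorch releases:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# toy tensors: (batch, heads, seq, head_dim); the bias plays the role of a relative
# positional term, which the flash-attention backend currently cannot accept
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
bias = torch.randn(2, 8, 128, 128, device="cuda", dtype=torch.float16)

# restrict kernel selection to the memory-efficient backend (MATH kept as a fallback)
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```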
What else do I need to do to merge this PR?
Usage
Here is my benchmark:
```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

torch.manual_seed(123)

device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
n_head = 8

query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()  # mask: True - make zero, False - leave unchanged
mask = None
pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

attention_sdpa = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=True).to(device)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=False).to(device)

for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
    original_param.data.copy_(sdpa_param.data)

# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask, pos_emb); torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb},
    )
    with torch.no_grad():
        torch.cuda.synchronize()
        results = timer.blocked_autorange(min_run_time=10)
        forward_time = results.mean
        output = attention(query, key, value, mask, pos_emb)
    return forward_time, output


def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='loss = attention(query, key, value, mask, pos_emb).sum(); torch.cuda.synchronize(); loss.backward(); torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb},
    )
    torch.cuda.synchronize()
    results = timer.blocked_autorange(min_run_time=10)
    fwd_bwd_time = results.mean
    return fwd_bwd_time


time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)
print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa
print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")
print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")

# Original implementation time: 0.049075 seconds
# SDPA implementation time: 0.035598 seconds
# SDPA boost 27.46%
# Original implementation backward time: 0.127004 seconds
# SDPA implementation backward time: 0.104986 seconds
# SDPA backward boost 17.34%
# Outputs are the same
```
PR Type:
- [x] New Feature
- [ ] Bugfix
- [ ] Documentation
Who can review?
cc @titu1994 @SeanNaren
Additional Information
- Related to (issue)
I also attempted to run the tests in the repository but encountered an issue. NaNs appear when a mask with a fully False row is passed to the MHA. With such a mask, filling `matrix_bd` with -inf values via `matrix_bd.masked_fill_(mask.logical_not(), float("-inf"))` produces a row consisting only of -inf, and after the softmax this entire row becomes NaN. I am unsure how to resolve this, since the softmax and the multiplication by value happen inside torch.nn.functional.scaled_dot_product_attention, and I cannot intervene there. In your implementation this is handled by manually filling with zeros: `attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)`.
What can be done about this? Should we write a separate test that doesn't use such masks as input?
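For reference, a tiny standalone repro of the failure mode described above (illustrative values, not NeMo code):

```python
import torch

scores = torch.randn(2, 4)                          # attention scores for two query rows
mask = torch.tensor([[True, True, False, False],    # partially masked row: fine
                     [False, False, False, False]]) # fully masked row: breaks softmax
scores = scores.masked_fill(mask.logical_not(), float("-inf"))
attn = torch.softmax(scores, dim=-1)
print(attn)  # second row is all NaN, because the exp(-inf) entries sum to zero
```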
@SeanNaren @titu1994 Having the option to use SDPA is a good thing also because a Triton-based version of FAv2 with custom attn_bias support (FlexAttention) is being added to PyTorch core: https://github.com/pytorch/pytorch/pull/130250#issuecomment-2216318222, so Conformer attention can benefit in the future from the speed-ups and proper compilation work on SDPA happening in core PyTorch.
@SeanNaren @titu1994 haha, and now that FAv3 is out, PyTorch will probably integrate it as well in the near term - for maximum brr on H100 :) So having NeMo's Conformer automatically benefit from this new work would be awesome.
cc @redoctopus @jbalam-nv @okuchaiev
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.
stale bump
Thanks for the contribution! @titu1994 please take a look at this PR as it looks like an interesting addition to speed up Conformer models.
Just some notes:

- Please use -10000 instead of -inf if possible, as -inf may cause NaN with some data types.
- Please add it as an option to the config files, somewhere like here, to be able to control it from configs: https://github.com/NVIDIA/NeMo/blob/07d536babf25f34d2866e7fa76c05eeef3b0086b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml#L147 Name it "use_pytorch_sdpa"?
- I suggest making it True by default if we can make sure it works in all cases? @titu1994 what do you think?
- Please evaluate one of the pretrained models from NGC on LibriSpeech test-other to make sure that it produces exactly the same output and accuracy.
- You need to set the dropout to zero manually in non-training mode, as SDPA does not respect that and always applies dropout (see the sketch after this list).
- Have you used matrix_ac in your code/calculations?
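A minimal sketch combining the -10000 and dropout points above (hypothetical helper and parameter names, not the PR's exact code); F.scaled_dot_product_attention applies dropout_p unconditionally, so the caller has to zero it outside training:

```python
import torch
import torch.nn.functional as F

def sdpa_with_bias(q, k, v, bias, mask, dropout_rate, training):
    # a large negative finite value instead of -inf avoids NaN issues in low precision
    if mask is not None:
        bias = bias.masked_fill(mask.logical_not(), -10000.0)
    # SDPA does not look at module.training, so dropout must be zeroed manually at eval time
    dropout_p = dropout_rate if training else 0.0
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=dropout_p)
```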
@VahidooX Big thanks for the review!

- I replaced -inf with -10000
- Added use_pytorch_sdpa to the config
- Fixed dropout for torch SDPA
- I don't compute matrix_ac manually, but it is computed under the hood of torch.nn.functional.scaled_dot_product_attention (see the reference implementation): `attn_weight = q_with_bias_u @ key.transpose(-2, -1) * scale_factor` # so matrix_ac is equivalent to attn_weight (a sketch of the equivalence follows below)

Do you have any script for calculating metrics on LibriSpeech? And reference metrics?
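To spell out the equivalence, here is a hedged sketch (simplified shapes and scaling; not the actual NeMo forward): the matrix_ac term is exactly the q·kᵀ product that SDPA computes internally, so only the pre-scaled positional matrix_bd term needs to be passed as the attention bias:

```python
import math
import torch
import torch.nn.functional as F

# q_with_bias_u, key, value: (batch, heads, seq, d_k); matrix_bd: (batch, heads, seq, seq)
def relpos_attention_via_sdpa(q_with_bias_u, key, value, matrix_bd):
    d_k = q_with_bias_u.size(-1)
    # original path: softmax((matrix_ac + matrix_bd) / sqrt(d_k)) @ value,
    # with matrix_ac = q_with_bias_u @ key.transpose(-2, -1)
    # SDPA computes matrix_ac / sqrt(d_k) internally, so only the positional term
    # is supplied, pre-scaled, as the additive attention bias
    return F.scaled_dot_product_attention(
        q_with_bias_u, key, value, attn_mask=matrix_bd / math.sqrt(d_k)
    )
```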
Thanks @titu1994. Okay, I'll try to add tests.
I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward.
I think the data in random_length in this test is incorrect, because random_length should have a single dimension of shape batch, not two dimensions. Am I right?
https://github.com/NVIDIA/NeMo/blob/6774fdcf23a915421135b223404b2213bf2c1b7a/tests/collections/asr/test_conformer_encoder.py#L84
> I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward. I think the data in random_length in this test is incorrect, because random_length should have a single dimension of shape batch, not two dimensions. Am I right? https://github.com/NVIDIA/NeMo/blob/6774fdcf23a915421135b223404b2213bf2c1b7a/tests/collections/asr/test_conformer_encoder.py#L84
Yes, that looks incorrect.
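For illustration, a length tensor shaped the way the test presumably intends (one value per batch element; the numbers are arbitrary):

```python
import torch

batch, seq_len = 4, 64
# one valid-length value per batch element, i.e. shape (batch,), not (batch, something)
random_length = torch.randint(low=1, high=seq_len + 1, size=(batch,))
print(random_length.shape)  # torch.Size([4])
```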
As I mentioned above here, I encountered an error caused by the mask for padding tokens being all False, which led to NaNs appearing after the softmax operation.
I resolved this issue by adding dummy columns and rows to the query, key, value, and mask: https://github.com/NVIDIA/NeMo/blob/8d292d6ef218d5dc5f40878cde2052448e08aabb/nemo/collections/asr/parts/submodules/multi_head_attention.py#L160-L162 However, this may not be the most appropriate solution, since it could affect the divisibility of dimensions by powers of two (a simplified illustration of the idea is sketched below).
Do you think this approach is acceptable?
Recently, a similar issue was resolved for SDPA: https://github.com/pytorch/pytorch/pull/131863. Once that fix is included in a stable PyTorch release, we can remove this workaround involving the additional columns.
Maybe we should run a training test with the memory-efficient backend before setting use_pytorch_sdpa=True by default. Does anyone have a ready script for this?
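For context, a heavily simplified illustration of the dummy-token idea (not the linked NeMo code; shapes, names, and mask semantics are assumptions): append one always-attendable key/value position so that no query row ends up with an all-False mask:

```python
import torch
import torch.nn.functional as F

def sdpa_with_dummy_column(q, k, v, mask, dropout_p=0.0):
    # q, k, v: (batch, heads, seq, d_k); mask: (batch, 1 or heads, seq_q, seq_k), True = attend
    b, h, s, d = k.shape
    k = torch.cat([k, k.new_zeros(b, h, 1, d)], dim=2)
    v = torch.cat([v, v.new_zeros(b, h, 1, d)], dim=2)
    # the extra column is always attendable, so no row of the mask is entirely False
    mask = torch.cat([mask, mask.new_ones(*mask.shape[:-1], 1)], dim=-1)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
```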
> As I mentioned above here, I encountered an error caused by the mask for padding tokens being all False, which led to NaNs appearing after the softmax operation. I resolved this issue by adding dummy columns and rows to the query, key, value, and mask: https://github.com/NVIDIA/NeMo/blob/8d292d6ef218d5dc5f40878cde2052448e08aabb/nemo/collections/asr/parts/submodules/multi_head_attention.py#L160-L162 However, this may not be the most appropriate solution, since it could affect the divisibility of dimensions by powers of two. Do you think this approach is acceptable?
> Recently, a similar issue was resolved for SDPA: pytorch/pytorch#131863. Once that fix is included in a stable PyTorch release, we can remove this workaround involving the additional columns.
Hmm, did you try using -10000 instead of -inf and then slicing off the excess, as suggested in the above PR as option 2?
If that doesn't work, I'm wondering if we should merge this without the workaround and guard this code path so it is only used once PyTorch 2.4 is released? We can do a version check and run the correct code for the relpos MHA (see the sketch below).
We can fall back to the older codebase for lower PyTorch versions. @VahidooX what do you think?
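A minimal sketch of such a version guard (the threshold and structure are illustrative only):

```python
import torch
from packaging import version

# guard the SDPA code path behind a minimum PyTorch version; older versions
# fall back to the original explicit-softmax relative-position MHA
_SDPA_PATH_AVAILABLE = version.parse(torch.__version__).release >= (2, 4)

def forward_attention(*args, **kwargs):
    if _SDPA_PATH_AVAILABLE:
        ...  # torch.nn.functional.scaled_dot_product_attention path
    else:
        ...  # original implementation
```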
Oh god, sorry for all the mess with git :(
So, @titu1994, yes, it was possible to get rid of these NaNs after the softmax by simply zeroing them out using a mask, which is what I just did:
```python
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)

# this if block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
    all_masked_rows = torch.all(~mask, dim=-1)
    all_masked_rows.unsqueeze_(-1)
    out = out.masked_fill(all_masked_rows, 0.0)
```
In the end, it seems that before merging with the flag set to True by default, it would be good to run a training test to ensure everything works fine. Does anyone have the code ready for this?
I'm on it - launched Canary-1B training with this, will keep you posted.
Unfortunately the Canary-1B training on 32 GPUs with bf16 AMP is diverging after ~10 hours with this SDPA implementation. I launched a second run with a different seed to confirm and it replicates. I have a ~1 month old baseline run that works fine, and I launched another from today's main that also runs fine. Any ideas what may have caused this?
@pzelasko I think the first step should be running SDPA with the most basic MATH backend selected:
```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# needs to be checked if this breaks torch.compile(fullgraph=True) or not, see https://github.com/pytorch/pytorch/issues/130098
with sdpa_kernel([SDPBackend.MATH]):
    # by default running with SDPBackend.EFFICIENT_ATTENTION, as it's supposedly the only backend besides MATH which supports a passed custom attn_bias arg
    x = F.scaled_dot_product_attention(...)
```
This would validate/benchmark that the basic SDPA implementation works fine (it should be the most similar to the original NeMo implementation). I have heard on the PyTorch forums/issues that the numerics of the different SDPA backends are slightly different, and this can lead to blow-ups during internal accumulation :( This is why I think use_pytorch_sdpa should not be the default just yet, and the SDPA backend selection should be exposed in the config too. In the future, if @WoodieDudy takes on https://github.com/Dao-AILab/flash-attention/pull/617, the FLASH_ATTENTION backend would also become available. And I'm not sure whether the modern CUDNN_ATTENTION backend can handle a custom attention bias - maybe that is supported now as well.
For more collaborative debugging, it would be great if some subset of multilingual ASRSET-1 were published as filelists. Then relatively large (larger than LibriSpeech) training runs could be done by users for validation (e.g. to try again using RoPE or ALiBi positional encoding + FAv2) and for more reproducible research.
Maybe we need to look at loss spikes, activation ranges, weight ranges... Maybe it can be fixed by simply clipping some QKV projection weights, or by skipping batches with a large loss (after the initial training) - this is not trivial to do with DDP, as replicas must stay in sync, but it is still possible...
@pzelasko what PyTorch version did you use? We found that the latest versions were more stable. Did you try with bf16_mixed?
So, let's make use_pytorch_sdpa False by default and merge this PR? And then maybe continue the work in a different PR.
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.
That sounds fine to me. Can you make it False by default? @titu1994 wdyt?
Sounds ok to me
I set the flag to False by default everywhere, and also added a use_pytorch_sdpa_backends argument in which you can set the list of backends for SDPA.
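A hedged sketch of how such a backend list might be applied (the string-to-enum mapping and function name are illustrative, not necessarily how the PR wires it into the config):

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def run_sdpa(q, k, v, bias, backend_names=("MATH", "EFFICIENT_ATTENTION")):
    # map config strings to SDPBackend members and constrain kernel selection accordingly
    backends = [getattr(SDPBackend, name) for name in backend_names]
    with sdpa_kernel(backends):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```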
Let's merge? @titu1994
It looks OK for inference at least; let's avoid it for training for now.
@VahidooX @pzelasko 👀
Uh, sorry for the lack of responsiveness lately. I used bf16-mixed and PyTorch 2.4.0 with CUDA 12.5. I agree we should merge it disabled by default. Let's take a look again later with the newer backends.
We should also enable it for inference by default if possible.
Just initiated the CI run: it looks like some of the tests are failing like this: https://github.com/NVIDIA/NeMo/actions/runs/11128061705/job/30922390803?pr=9590
Please address them, and once CI passes this is good to go. Thanks!