Masaki Kozuki
@jqueguiner Thanks for the PR, but I'm afraid we need to relax the condition in the backward path as well: https://github.com/NVIDIA/apex/blob/f7421555c3d2ff01eed0e7c0c4321f3e4dd58fc6/apex/contrib/csrc/fmha/fmha_api.cpp#L182-L191. Thank you for your patience with my delayed response.
Unfortunately, I found that FMHA failed when I ran https://github.com/NVIDIA/apex/blob/master/apex/contrib/test/fmha/test_fmha.py with this PR on an A40, possibly because the kernel violates some limitation. cc'ing @yjk21 for visibility.
seems related to https://github.com/NVIDIA/NeMo/pull/3998
@LBNord @xiaoyu-work @Mortimerp9 Apologies for bothering you. I've added a guard in #1253. Also, thank you for reporting this.
I want PyTorch to ship `ATen/native/utils/*.h`, `torch/csrc/jit/codegen/cuda/*.h`, and `torch/csrc/jit/codegen/cuda/ops/*.h` before merging this. For that, I think we need to update:
- https://github.com/pytorch/pytorch/blob/543eaac415bfba59e648a3da8be5c4f964f9bc6e/setup.py#L938
- https://github.com/pytorch/pytorch/blob/543eaac415bfba59e648a3da8be5c4f964f9bc6e/aten/src/ATen/CMakeLists.txt#L149
- ...
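A quick way to see which of these header groups an installed PyTorch already ships is to glob for them under its include directory. This is a minimal sketch, not part of the PR: `missing_header_patterns` is a hypothetical helper, and it assumes the installed headers live under an `include` directory inside the `torch` package.

```python
import glob
import os

# Header groups named in the comment above, relative to the include root.
PATTERNS = [
    "ATen/native/utils/*.h",
    "torch/csrc/jit/codegen/cuda/*.h",
    "torch/csrc/jit/codegen/cuda/ops/*.h",
]


def missing_header_patterns(include_root, patterns=PATTERNS):
    """Return the glob patterns that match no file under include_root.

    Hypothetical helper for illustration; it only inspects the filesystem.
    """
    missing = []
    for pattern in patterns:
        # glob.glob returns an empty list when nothing matches the pattern.
        if not glob.glob(os.path.join(include_root, pattern)):
            missing.append(pattern)
    return missing
```

With PyTorch installed, something like `missing_header_patterns(os.path.join(os.path.dirname(torch.__file__), "include"))` would list the groups that are not shipped yet. An empty list would suggest the headers are already in place.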
Getting closer to removing the `PYTORCH_HOME` env var dependency: https://github.com/pytorch/pytorch/pull/78281
You need to build apex with `--global-option="--fast_layer_norm"` if you want to use apex/contrib/layer_norm: `pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./`
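To check whether the optional extension was actually built in a given environment, one option is to look the module up without importing it. This is a minimal sketch: `extension_available` is a hypothetical helper, and the module name `fast_layer_norm` is an assumption about what the `--fast_layer_norm` option produces.

```python
import importlib.util


def extension_available(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for malformed or half-initialised modules.
        return False


# Assumption: the compiled extension is importable as "fast_layer_norm".
if not extension_available("fast_layer_norm"):
    print("fast_layer_norm not found; rebuild apex with --fast_layer_norm")
```

Because the lookup does not import the module, it will not catch an extension that is present but fails to load, for example after a CUDA or PyTorch version mismatch.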
Could you confirm that with this change we can build the `--cuda_ext` extension on Linux?
Sorry for the trouble. I wrote a patch, #1211. Could some of you try https://github.com/NVIDIA/apex/commit/25bfcb914796c7955956fc1e39839e7efed997d5 and let me know whether the issue is solved?
I don't know your setup, but one local run was as follows:
```
root@a33ccb515d34:/opt/pytorch/apex# python apex/contrib/test/fmha/test_fmha.py
Test s=128 b=32, zero_tensors=False
Test s=128 b=32, zero_tensors=True
.Test s=256 b=32, zero_tensors=False
Test s=256...
```