Kaixi Hou

Results: 15 issues by Kaixi Hou

After more extensive testing, we think we need to revert the PRs related to the NHWC+TF32 changes: https://github.com/tensorflow/tensorflow/pull/55761 https://github.com/tensorflow/tensorflow/pull/55806 https://github.com/tensorflow/tensorflow/pull/55920 During the tests, we noticed that in some cases the...

Labels: awaiting review, ready to pull, size:M

This PR enables the cuDNN matmul fusion backend to support generic matmul fusion patterns, focusing on the matmul+bias+tanh|sigmoid pattern. It is on top of these...
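As a rough illustration (not the PR's code), the kind of unfused subgraph this fusion targets looks like the following; tensor shapes are illustrative:

```python
import tensorflow as tf

x = tf.random.normal([8, 64])
w = tf.random.normal([64, 32])
b = tf.random.normal([32])

# Unfused pattern that a remapper can rewrite into a single fused
# cuDNN matmul call: MatMul -> BiasAdd -> Tanh (Sigmoid is analogous).
y = tf.math.tanh(tf.nn.bias_add(tf.linalg.matmul(x, w), b))
```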

Labels: awaiting review, size:XL, comp:core

This PR enables the cuDNN matmul fusion backend to support generic matmul fusion patterns, focusing on the matmul+bias+gelu_exact pattern. (Note that matmul+bias+gelu_approximate is already supported...
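For reference, the distinction between the two GeLU flavors mentioned here, as a standalone sketch (standard formulas, not the PR's implementation):

```python
import math
import tensorflow as tf

def gelu_exact(x):
    # Exact GeLU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + tf.math.erf(x / math.sqrt(2.0)))

def gelu_approximate(x):
    # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + tf.math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))
```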

Labels: awaiting review, size:L, comp:core

This PR enables the cuDNN matmul fusion backend to support generic matmul fusion patterns, focusing on the matmul+bias+gelu_exact pattern. (Note that matmul+bias+gelu_approximate is already supported...

Labels: awaiting review, size:L

This PR adds support for the Conv+Bias+Relu6/Elu/LeakyRelu fusion patterns on GPUs. This is realized by using the cuDNN graph API, which can utilize runtime-compiled kernels on Ampere...
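A minimal sketch of the unfused pattern involved (shapes are illustrative, not from the PR):

```python
import tensorflow as tf

x = tf.random.normal([1, 56, 56, 64])   # NHWC input
f = tf.random.normal([3, 3, 64, 128])   # HWIO filter
b = tf.random.normal([128])

# Unfused pattern: Conv2D -> BiasAdd -> Relu6 (Elu/LeakyRelu analogous).
y = tf.nn.relu6(tf.nn.bias_add(
    tf.nn.conv2d(x, f, strides=1, padding="SAME"), b))
```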

Labels: awaiting review, ready to pull, size:L

This pull request introduces a custom data-type rule for the FP8 parameters to implement custom gradient accumulation. Specifically, when the FP8 parameters are reused, autograd accumulates their gradients...
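To illustrate the default accumulation behavior being worked around (a minimal sketch, not the PR's rule): when a parameter is consumed at more than one site, JAX's autograd sums the per-site gradient contributions.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # w is consumed at two sites; by default, autograd adds the
    # two gradient contributions together.
    return jnp.sum(x @ w) + jnp.sum((2.0 * x) @ w)

w = jnp.ones((4, 3))
x = jnp.ones((2, 4))
g = jax.grad(loss)(w, x)  # the sum of the per-site gradients
```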

This PR allows users to enable cuDNN flash attention. It depends on https://github.com/google/praxis/pull/53. In preliminary results for GPT3-5B, we observe a ~30% perf improvement on...

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention...
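The shared workflow these variants follow, as a minimal scaled dot-product sketch (shapes and names are illustrative):

```python
import jax.numpy as jnp
from jax.nn import softmax

def attention(q, k, v):
    # q, k, v: [batch, seq, heads, head_dim]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    probs = softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)
```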

This PR renames the original `fm32` to `fp32_max_grad` to express the idea that this dtype is used for storing fp32 values while using max for gradient accumulation. cc @nouiz
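A toy contrast of the two accumulation semantics (illustrative values, not the PR's code):

```python
import jax.numpy as jnp

# Two reuse sites each propose an fp32 statistic (e.g., an amax used
# for fp8 scaling).
g_a = jnp.array([0.5, 2.0])
g_b = jnp.array([1.5, 0.1])

g_sum = g_a + g_b              # default autograd accumulation
g_max = jnp.maximum(g_a, g_b)  # fp32_max_grad-style accumulation
```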

For the current fp8 gemm, we set c_scale to one, though it is effectively never used. Newer cublasLt, however, has a stricter requirement that c_scale can be set only when...
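For context, a sketch of the scaled-gemm semantics involved (scale names follow the usual cublasLt FP8 convention; this is an assumption-laden illustration, not the PR's code):

```python
import jax.numpy as jnp

# Scaled fp8 gemm semantics:
#   D = alpha * (a_scale * A) @ (b_scale * B) + beta * (c_scale * C)
# With beta == 0, the C operand (and hence c_scale) is never consumed,
# which is why passing a dummy c_scale of one used to be harmless.
A = jnp.ones((4, 8))
B = jnp.ones((8, 2))
alpha, beta, a_scale, b_scale = 1.0, 0.0, 1.0, 1.0
D = alpha * (a_scale * A) @ (b_scale * B)  # beta == 0: no C term
```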