Enable ReLU epilogue fusion for cublasLt matmul for training
To enable ReLU epilogue fusion for cublasLt matmul during training, two pairs of epilogues are added: (RELU_AUX, DRELU) and (BIAS_RELU_AUX, DRELU_BGRAD). The RELU_AUX (or BIAS_RELU_AUX) epilogue on the forward matmul writes a bitmask to an auxiliary buffer, which the backward matmul then consumes. This PR targets both non-fp8 and fp8 matmuls. @kaixih @philipphack @reedwm
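For intuition, here is a minimal JAX sketch of what the (RELU_AUX, DRELU) epilogue pair computes. It is illustrative only: the function names are hypothetical, a boolean array stands in for the packed bitmask, and the real cublasLt kernels perform the masking inside the GEMM epilogue rather than as separate ops.

```python
import jax.numpy as jnp

def matmul_relu_aux(x, w):
    # Emulates a forward matmul with a RELU_AUX epilogue: apply ReLU to the
    # GEMM output and record which elements survived in an auxiliary mask.
    y = x @ w
    mask = y > 0                     # stand-in for the packed bitmask
    return jnp.where(mask, y, 0), mask

def dgrad_drelu(dy, w, mask):
    # Emulates a backward (dgrad) matmul with a DRELU epilogue: the GEMM
    # produces the gradient at the ReLU output, and the saved mask zeroes
    # the positions ReLU clamped, so the forward activations are never
    # re-read to recompute the ReLU derivative.
    d = dy @ w.T
    return jnp.where(mask, d, 0)
```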
This PR aims to pattern-match the following forward pass:
y = matmul1(x1, weight1)
y = relu(y)
y = matmul2(y, weight2)
and its corresponding backward pass:
dy = grad_matmul2(dy, weight2)
dy = drelu(dy)
dy = grad_matmul1(dy, weight1)
Here matmul1, matmul2, grad_matmul1, and grad_matmul2 may also be fp8 matmuls, and a vector bias may be present. In the forward pass, ReLU is fused into matmul1; in the backward pass, drelu is fused into grad_matmul2 (see the sketch after the timings below). For a matmul whose first operand has dimensions m=16384 and k=n=12288, the performance comparison yields:
With optimization: execution time 32.98 ms
Without optimization: execution time 33.52 ms
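Below is a minimal JAX sketch of a user program that produces this pattern; the program is hypothetical, not taken from the PR, and the shapes mirror the benchmark above. Under jit, the forward matmul+relu and the backward dgrad+drelu lower to the cublasLt calls that this PR's pattern matcher is intended to fuse.

```python
import jax
import jax.numpy as jnp

def loss(x, w1, w2):
    y = x @ w1              # matmul1
    y = jax.nn.relu(y)      # fused into matmul1 as a RELU_AUX epilogue
    return jnp.sum(y @ w2)  # matmul2; summed to a scalar so grad applies

# jax.grad builds the backward pass grad_matmul2 -> drelu -> grad_matmul1,
# with drelu fused into grad_matmul2 as a DRELU epilogue.
grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))

key = jax.random.PRNGKey(0)
x  = jax.random.normal(key, (16384, 12288), dtype=jnp.bfloat16)  # m x k
w1 = jax.random.normal(key, (12288, 12288), dtype=jnp.bfloat16)  # k x n
w2 = jax.random.normal(key, (12288, 12288), dtype=jnp.bfloat16)
dx, dw1, dw2 = grad_fn(x, w1, w2)
```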
Hi @wenscarl, can you have a look at the comments? Thanks.
All issues have been addressed.
Hi @wenscarl, can you address the comments above? Thanks.
Hi @reedwm , can you please have a look into this PR? Thanks.
@wenscarl Could you please help address @reedwm's comments?
@wenscarl It seems this PR has been inactive for two months. Please provide an update if you still intend to merge it; otherwise we will close it after some further inactivity.
This matmul-relu-matmul pattern is not commonly seen in models. Closing.