[Performance] GSpMM regression with FP16
I spent some time profiling the GAT example with AMP from https://docs.dgl.ai/en/0.9.x/guide/mixed_precision.html and want to understand why we don't get a performance gain from FP16. I observed a regression in both the forward and backward phases.
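For reference, a minimal AMP training step for a GAT layer in DGL looks roughly like the sketch below (the random graph, layer sizes, and dummy loss are illustrative, not the exact benchmark):

```python
import torch
import dgl
from dgl.nn import GATConv

device = 'cuda:0'
# Illustrative graph; the profiled example uses 602-dimensional input features
# projected to 256 dimensions (see below).
g = dgl.rand_graph(10000, 100000).to(device)
feat = torch.rand(g.num_nodes(), 602, device=device)

# allow_zero_in_degree avoids errors on the random graph used here
model = GATConv(602, 256, num_heads=4, allow_zero_in_degree=True).to(device)
opt = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    opt.zero_grad()
    with torch.cuda.amp.autocast():      # forward runs in FP16 where eligible
        out = model(g, feat)
        loss = out.float().sum()         # dummy loss, just to trigger backward
    scaler.scale(loss).backward()        # backward therefore also sees FP16 tensors
    scaler.step(opt)
    scaler.update()
```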
Performance on A100:

 | AMP | FP32
---|---|---
forward | 254 ms | 160 ms
backward | 482 ms | 369 ms
step | 738 ms | 529 ms
Take the forward phase for example: there are three EdgeSoftmax-GSpMM pairs.

(Profiler timelines for AMP and FP32.)

Elapsed times (ms) of the three EdgeSoftmax-GSpMM pairs:
 | AMP | FP32
---|---|---
EdgeSoftmax1 | 20.896 | 22.055
GSpMM1 | 73.851 | 31.890
EdgeSoftmax2 | 20.865 | 22.080
GSpMM2 | 73.797 | 31.882
EdgeSoftmax3 | 20.869 | 22.154
GSpMM3 | 27.528 | 13.422
We can see that EdgeSoftmax is a bit faster with AMP, while it's GSpMM that causes the regression. Digging a bit deeper: GSpMM with FP32 inputs goes through the cusparse::csrmm code path, whereas with FP16 inputs it falls back to dgl::aten::cuda::SpMMCsrKernel. I'll look into the FP16 performance issue on the cusparse side.
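One way to confirm which kernel gets dispatched in each case is to profile a single GSpMM call; a sketch with torch.profiler (the random graph and shapes are illustrative):

```python
import torch
import dgl
from dgl.ops import gspmm
from torch.profiler import profile, ProfilerActivity

device = 'cuda:0'
g = dgl.rand_graph(10000, 100000).int().to(device)
u = torch.rand(g.number_of_src_nodes(), 256, dtype=torch.float16, device=device)
e = torch.rand(g.number_of_edges(), dtype=torch.float16, device=device)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    gspmm(g, 'mul', 'sum', u, e)
    torch.cuda.synchronize()

# The CUDA kernel names show whether cusparse's csrmm or DGL's SpMMCsrKernel ran.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```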
ENV
- CUDA_VERSION: 11.7.0.022
- CUDA_DRIVER_VERSION: 515.43.04
- CUSPARSE_VERSION: 11.7.3.50
- GPU: NVIDIA A100 80GB PCIe
- PyTorch: 1.12.0a0+8a1a93a
I encountered a similar issue a year ago. I remember that cusparse FP16 requires some alignment (e.g., the array pointer address should be a multiple of 16/32/64 bytes, etc.) to be efficient.
@yaox12 Can you add cuda versions?
Added.
@yaox12 I just found a code snippet about the alignment issue written by @nv-dlasalle last year: https://github.com/nv-dlasalle/dgl/commit/5fc6e9bfc5fbd59e0cf5dbc4510883a5a124a467
The matrix columns need to be aligned to 128 bytes for best performance.
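For what it's worth, a quick way to check this kind of alignment from Python (a sketch; here u stands for the dense node-feature matrix and e for the edge weights passed to GSpMM):

```python
import torch

def report_alignment(name, t, boundary=128):
    """Report pointer alignment and, for matrices, the row size in bytes."""
    print(f"{name}: data_ptr % {boundary} = {t.data_ptr() % boundary}")
    if t.dim() == 2:
        row_bytes = t.stride(0) * t.element_size()
        print(f"  row size = {row_bytes} bytes"
              f" (multiple of {boundary}: {row_bytes % boundary == 0})")

# Shapes mirror the benchmark below: FP16 features, 256 * 2 B = 512 B per row.
u = torch.rand(1000, 256, dtype=torch.float16, device='cuda')
e = torch.rand(5000, dtype=torch.float16, device='cuda')
report_alignment('u', u)
report_alignment('e', e)
```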
I don't think this is a data-alignment issue, because in this GAT model the node features are projected from 602 to 256 dimensions before invoking GSpMM, and a 256-dimensional FP16 row is 512 bytes, which is already a multiple of 128.
We observed a big regression (~8x slower) of cusparseSpMM with FP16 compared to FP32 on A100/V100/A5000/RTX3090. For example, on A100, it's 250.24 ms (FP16) vs 32.60 ms (FP32). We are tracking it internally.
Here is my code (I have enabled cusparseSpMM for FP16 in spmm.cu):
```python
import time
import argparse

import torch
import dgl
from dgl.data import RedditDataset
from dgl.ops import gspmm

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, choices=['fp16', 'fp32'])
parser.add_argument("--feat-len", type=int, choices=[256, 602])
args = parser.parse_args()

if args.dtype == 'fp16':
    dtype = torch.float16
else:
    dtype = torch.float32
device = 'cuda:0'

g = RedditDataset(raw_dir='.')[0]
g = g.int().to(device)

u = torch.rand((g.number_of_src_nodes(), args.feat_len), dtype=dtype, device=device)
e = torch.rand((g.number_of_edges(),), dtype=dtype, device=device)

# dry run
for _ in range(5):
    gspmm(g, 'mul', 'sum', u, e)
torch.cuda.synchronize()

tic = time.time()
for _ in range(10):
    gspmm(g, 'mul', 'sum', u, e)
torch.cuda.synchronize()
# total seconds * 1000 / 10 iterations = average ms per call
print(f'{dtype}, {args.feat_len}: {(time.time() - tic) * 100:.2f} ms')
```
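As a side note (not part of the original script), the same loop can also be timed with CUDA events, reusing g, u, e, dtype, and args from above; this keeps host-side overhead out of the measurement:

```python
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(10):
    gspmm(g, 'mul', 'sum', u, e)
end.record()
torch.cuda.synchronize()
# elapsed_time() returns milliseconds; divide by the iteration count
print(f'{dtype}, {args.feat_len}: {start.elapsed_time(end) / 10:.2f} ms')
```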
It's not only about the feature size; you should also check the pointer address of each operand (A, B, and C). To the best of my knowledge, achieving such performance on Reddit does not require any special optimizations.
Sputnik also has some issues about alignment for fp16: https://github.com/facebookresearch/xformers/issues/15
@yaox12 Thank you for sharing your code. I tested it on my side (A5000, CUDA 11.7), but I observed different results:
feat dim | FP32 (ms) | FP16 (ms) | FP16 speed-up
---|---|---|---
32 | 17.6 | 18.75 | 0.94
64 | 21.35 | 22.1 | 0.97
128 | 40.46 | 28.7 | 1.41
256 | 78.24 | 55.59 | 1.41
512 | 156.18 | 100.48 | 1.55
602 | 203.78 | 140.86 | 1.45
I also forced the code to use cusparseSpMM for FP16. Note that with FP16 inputs, I have to use CUDA_R_32F as the computation data type (computeType) when calling cusparseSpMM, according to https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm.
Hi @yaox12, is this done already?
Still tracking it internally.