[Roadmap] CPU Performance Optimization for PyG
🚀 The feature, motivation and pitch
The goal of this roadmap is to optimize CPU performance for PyG (including torch_scatter, torch_sparse).
For the first step, we will start with single node inference performance optimization on:
- Homogeneous Models: GCN, GAT, PNA, EdgeConv
- Heterogeneous Models: to_hetero, R-GCN, R-GAT
The next step will extend the optimization effort to (distributed) training.
Performance Profiling
CPU platform: Icelake Xeon
Generic benchmarking
- [x] GCN + ogbn-products: (torch_sparse::spmm_sum 96.04%)
- [x] GCN + reddit: layer=1, hidden=16 (DataLoader 83.49%, aten::scatter_add_ 8.47%)
- [x] GCN + reddit: layer=3, hidden=32 (DataLoader 59.83%, aten::scatter_add_ 24.76%)
- [x] SAGE + ogbn-products: (aten::scatter_add_ 27.61%, DataLoader 25.70%, aten::index 20.26%)
- [x] GAT + CiteSeer: (aten::scatter_add_ 30.91%, torch_scatter::scatter_max 24.54%, aten::mm 10.34%, aten::index_select 6.71%); most of the models under pytorch_geometric/benchmark/citation show similar behavior from a performance perspective.
- [x] to_hetero_mag: (aten::addmm 21.69%, aten::scatter_add_ 20.60%, aten::index_select 13.48%, DataLoader 12.31%)
- [x] PNA: (torch_scatter::scatter_max 39.34%, torch_scatter::scatter_min 39.25%); follow-up needed: get the scatter_reduce tensor shape/stride (similar issue as aten::scatter_add_?)
- [x] DynamicEdgeConv: (torch_scatter::scatter_max 66.91%, torch_cluster::knn 23.56%); source: benchmark/points/edge_cnn.py
- [x] EdgeConv: (torch_scatter::scatter_max 53.61%, aten::index_select 21.73%, DataLoader 16.11%); source: https://github.com/pyg-team/pytorch_geometric/pull/4915
- [x] pytorch_geometric/benchmark/kernel
- [ ] pytorch_geometric/benchmark/points
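For reference, the per-operator breakdowns above can be reproduced with torch.profiler; below is a minimal, self-contained sketch that uses a toy scatter-based aggregation in place of a real model (names and sizes are illustrative):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Toy stand-in for one GNN aggregation step: gather source features, then
# scatter-add them into the destination nodes (the COO path).
num_nodes, num_edges, hidden = 10_000, 200_000, 64
x = torch.randn(num_nodes, hidden)
row, col = torch.randint(num_nodes, (2, num_edges))

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        msg = x.index_select(0, col)                                   # aten::index_select
        out = torch.zeros(num_nodes, hidden)
        out.scatter_add_(0, row.view(-1, 1).expand(-1, hidden), msg)   # aten::scatter_add_

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```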
Large dataset benchmarking
- [x] GraphSAGE + mag240m profiling (table below)
- [x] analysis of the ratio of profiler-recorded time against total runtime (evaluate other overhead, such as numpy calls, if any)
- [x] gather the input range for spmm_{sum|max|mean} for oneDNN RFC proposal (future plan)
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
enumerate(DataLoader)#_MultiProcessingDataLoaderIter... 71.27% 608.842s 71.39% 609.891s 70.223ms 8685
torch_sparse::spmm_mean 14.91% 127.390s 14.93% 127.522s 7.342ms 17370
aten::addmm 3.77% 32.166s 7.34% 62.727s 1.806ms 34740
aten::copy_ 3.60% 30.766s 3.60% 30.766s 161.007us 191082
aten::mm 2.29% 19.588s 2.30% 19.683s 1.133ms 17370
aten::native_batch_norm 0.94% 7.989s 1.01% 8.657s 332.256us 26055
DataLoader (with preprocessing of the input data) is the major bottleneck here, mostly from_numpy (246s) and to (169s) triggered by data type conversion; source: convert_batch.
Performance Hotspots
- DataLoader (mini-batch mode): mostly introduced by preprocessing in the samplers (e.g. NeighborSampler).
- edge_index in CSR: spmm_sum or spmm_max from torch_sparse (CSR memory format).
- edge_index in COO: scatter_add, scatter_max (torch_scatter), index_select, index, etc.
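To make the two edge_index paths concrete, here is a hedged sketch (assuming torch_sparse is installed; tensor names are illustrative) of the same sum aggregation expressed once through the COO/scatter path and once through the SparseTensor/spmm path:

```python
import torch
from torch_sparse import SparseTensor, matmul

num_nodes, num_edges, hidden = 1_000, 5_000, 32
x = torch.randn(num_nodes, hidden)
row, col = torch.randint(num_nodes, (2, num_edges))    # edge_index in COO: [2, num_edges]

# COO path: gather + scatter_add (hits aten::index_select / aten::scatter_add_).
out_coo = torch.zeros(num_nodes, hidden)
out_coo.scatter_add_(0, row.view(-1, 1).expand(-1, hidden), x.index_select(0, col))

# CSR path: the same sum aggregation as a sparse-dense matmul
# (hits torch_sparse::spmm_sum through SparseTensor's CSR layout).
adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))
out_csr = matmul(adj, x, reduce="sum")   # == out_coo up to duplicate-edge handling
```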
Python level API upgrade in model scripts
The DataLoader is a major hotspot, so the first step is to upgrade the model scripts from NeighborSampler to NeighborLoader, which has a native C++ implementation (see the sketch after the list below):
- [ ] mag240m
- [ ] ogbn-products
- [ ] pytorch_geometric/benchmark/kernel
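As an illustration of the intended change, a hedged sketch with a synthetic graph standing in for the datasets above (assumes a torch_geometric version where NeighborLoader is available):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Synthetic graph standing in for e.g. ogbn-products / mag240m.
num_nodes = 10_000
data = Data(
    x=torch.randn(num_nodes, 64),
    edge_index=torch.randint(num_nodes, (2, 100_000)),
    train_mask=torch.rand(num_nodes) < 0.1,
)

# Old: NeighborSampler yields (batch_size, n_id, adjs) and needs manual feature
# slicing in the training loop. New: NeighborLoader does the sampling natively
# and yields ready-to-use mini-batch Data objects.
loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],        # fanout per layer
    batch_size=1024,
    input_nodes=data.train_mask,
    num_workers=0,                 # keep sampling in-process; rely on intra-op parallelism
)

for batch in loader:
    # The first `batch.batch_size` nodes of `batch.x` are the seed nodes.
    seed_x = batch.x[:batch.batch_size]
```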
Native level kernel optimization
Phase One Optimizations
- [ ] NeighborLoader parallelization: the current implementation is sequential (probably to avoid oversubscription with multiple workers on the data loader). Unlike GPU runs, running the data loading thread and the computation thread asynchronously does not always make sense on CPU. In some cases it is better to run the data loading step and the computation step sequentially while making each torch operator parallel with OpenMP (i.e. intra-op parallelism). Hotspot: torch_sparse::neighbor_sample.
- [x] aten::sort: GCN + ogbn-products spent roughly 1/3 of its time on sort in the preprocessing step (not covered in the profiler result for model inference), introduced by indexing from the sparse tensor at gnn.py#L123. Root cause: aten::sort(dim) could only be parallelized on dimensions != dim, and the grain size was not set correctly. Fixed by #74897.
- [ ] spmm_{max|mean|sum} (torch_sparse): add vectorization and prefetching (for the indirect memory access) and apply blocking on M and K (if necessary).
- [ ] scatter_add and scatter_max (torch_scatter): optimized scatter_add (with extended index) in #82703; still needs more polishing. The current scatter_add implementation parallelizes on the inner dimension to avoid write conflicts; ideally we should parallelize on the outer dimension and vectorize on the inner dimension, but that requires resolving the write conflicts on the output tensor. Experiment with different implementations for the given input range.
- [x] index_select: optimized via #76868.
- [ ] index: directly optimizing index would be difficult; maybe we can change it to more performant ops like index_select from NeighborLoader, or customize its kernel from NeighborLoader (see the sketch below).
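For the last item, a small illustration of the kind of rewrite being considered: advanced indexing with a 1-D LongTensor dispatches to aten::index, while the equivalent index_select call hits the already-optimized aten::index_select path.

```python
import torch

x = torch.randn(10_000, 64)
idx = torch.randint(10_000, (4_096,))

out_index = x[idx]                    # dispatches to aten::index
out_select = x.index_select(0, idx)   # dispatches to aten::index_select

assert torch.equal(out_index, out_select)
```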
Phase Two Optimizations
- [ ] kernel fusion for GAS, #71300; maybe dispatch on TensorTypeId of CPU and SparseCPU.
- [ ] scatter_add: cache the sorted index.
- [ ] knn (torch_cluster): need follow-up shape info to determine the proper way to parallelize the kernel. Evaluate knn from oneAPI DAL.
Design option for vectorization
To vectorize kernels from torch-sparse and torch-scatter, we have multiple options:
- vectorize inside torch-sparse/torch-scatter: the simplest way is to use #pragma omp simd and add a compiler flag like -march=skylake-avx512, but this won't apply to bfloat16 (bfloat16 is an overload of uint16 and won't be vectorized properly by the compiler).
- vectorize inside torch-sparse/torch-scatter: use the at::vec::Vectorized<scalar_t> wrapper; this applies to bfloat16, but we need to customize the cmake scripts to make them compatible with PyTorch's CPU build flags: _DEFAULT (scalar code), _AVX2 and _AVX512.
- vectorize inside torch core: in this manner at::vec::Vectorized<scalar_t> will work without any change, but the operator needs to be moved from torch-sparse/torch-scatter into torch. Makes more sense for the fused kernel of GAS.
(current decision is to go with option 3 as much as we can)
Bfloat16 enabling in torch-sparse/torch-scatter
(highly related to the vectorization method chosen)
- [ ] Need more work to determine operator list for bfloat16 support.
Validation
- [ ] verify float32 accuracy
- [ ] verify bfloat16 accuracy
Thanks for this detailed issue and roadmap. PyTorch recently released torch.scatter_reduce as well. As such, the long-term goal is to move to the PyTorch implementation of the torch.scatter_reduce routines, and current optimizations of torch-scatter are probably not future-proof as a result. Can we also benchmark torch.scatter_reduce and torch_scatter to see if there is already a performance gain by simply switching the implementation?
Ok, I see. Then it is better to optimize scatter_reduce in torch. I just checked the code: scatter_add and scatter_reduce share the same kernel in torch, so they have the same performance issues. Will have it fixed.
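A hedged sketch of the micro-benchmark suggested above (assumes torch_scatter is installed; shapes mirror the broadcasted-index case discussed later in this thread, and numbers will vary by machine):

```python
import torch
import torch_scatter
from torch.utils.benchmark import Timer

num_edges, num_nodes, hidden = 477_263, 135_361, 256
src = torch.randn(num_edges, hidden)
index = torch.randint(num_nodes, (num_edges,))
index2d = index.view(-1, 1).expand(-1, hidden)   # broadcasted index for torch.scatter_reduce_

t_torch = Timer(
    "torch.zeros(num_nodes, hidden).scatter_reduce_(0, index2d, src, reduce='sum')",
    globals=globals(),
)
t_ext = Timer(
    "torch_scatter.scatter(src, index, dim=0, dim_size=num_nodes, reduce='sum')",
    globals=globals(),
)
print(t_torch.blocked_autorange())
print(t_ext.blocked_autorange())
```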
hey @mingfeima this issue is great, a lot of detail.
> NeighborLoader parallelization: the current impl is sequential (probably to avoid oversubscription with multiple workers on the data loader). Unlike GPU runs, running the data loading thread and the computation thread asynchronously does not always make sense. In some cases it is better to run the data loading step and the computation step sequentially while making each torch operator parallel on OpenMP (intra-parallelism).
I'm a complete newbie to this, so my question is to learn not suggest something. Can you explain what you're intending to change here? It sounds like you want to keep it sequential but maybe the actual sampling itself parallel?
> I'm a complete newbie to this, so my question is to learn not suggest something. Can you explain what you're intending to change here? It sounds like you want to keep it sequential but maybe the actual sampling itself parallel?
Yes, that's the idea! The data loader from pytorch is more suitable for GPU (by setting num_workers=N, it will launch data loading threads asynchronously with the main computation thread). On CPU, it is probably better to run data loading and computation sequentially while parallelizing the sampler inside the data loader with OpenMP threads.
That makes complete sense. Let me know if you'd like any help (I'd likely not be quick though 😅)
@mingfeima I assume the benchmarks have been run with num_workers=0? This explains why this is a bottleneck. Can you share some insights on how an OpenMP implementation of sampling behaves in relation to num_workers>0? Is it expected that this will potentially slow down the code compared to a single threaded implementation that utilizes parallelism solely on the worker level?
The current benchmark profiling results use the default settings. Some scripts, for example to_hetero_mag, explicitly set num_workers; if not, the pytorch default setting will be 4.
The DataLoader time in the benchmark profiling results actually consists of two parts:
- IO: load data from disk to memory
- pre processing: sampling, data type conversion, etc.
The second part takes more time, so it can still be improved with a single worker + parallel OpenMP. If we use num_workers>0, we need to make sure OpenMP inside each worker has the correct settings (omp_num_threads and core affinity binding) to avoid oversubscription.
Actually, data loader optimization is a rather complex issue, perhaps more complex than optimizing the kernels :( since it is more of a tuning job to achieve the most balanced situation between the workload payload (memory footprint, computational complexity, etc.) and the hardware capacity (IO, memory bandwidth, ALU flops).
Usually we do not do data loading optimizations, since the real case in deployment would probably be even more complex (some vendors have mechanisms like prefetching and batching to improve overall user experience and efficiency). But the thing is, DGL has done some optimizations here, so we need to do at least something similar, otherwise out-of-the-box performance on PyG would look bad.
Anyway, we will make sure that OpenMP has correct settings for either num_workers=0 or num_workers=N, and that each of the samplers can be properly parallelized. num_workers=0 benefits the pre-processing more and num_workers=N benefits the IO more. Then let the users decide which way to go (maybe we can give a BKM or some simple guideline; see the sketch below).
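A rough sketch of what such a guideline could look like (values are illustrative, not a tuned recommendation):

```python
import os
import torch
from torch_geometric.loader import NeighborLoader

# Option A: num_workers=0 -- sampling runs in the main process, and every torch
# operator parallelizes across all cores via OpenMP (intra-op parallelism).
torch.set_num_threads(os.cpu_count())
# loader = NeighborLoader(data, num_neighbors=[25, 10], batch_size=1024, num_workers=0)

# Option B: num_workers=N -- N sampling workers overlap with compute; cap the
# OpenMP threads so the workers plus the main process do not oversubscribe the
# cores, and pin them to one socket, e.g.:
#   OMP_NUM_THREADS=4 numactl --cpunodebind=0 --membind=0 python train.py
# loader = NeighborLoader(data, num_neighbors=[25, 10], batch_size=1024,
#                         num_workers=4, persistent_workers=True)
```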
Updates on scatter_add optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/82703
Initiative
Depending on the type of edge_index, message passing will choose different paths: a) scatter_add for a dense tensor; b) spmm for a SparseTensor. The principal factor here is the memory format: in the first case the memory format of edge_index is COO, and in the second case it is CSR.
Problem description
scatter_add is used to aggregate info row-wise, which means the index tensor is extended (all rows have identical values).
A typical input shape for the Reddit dataset looks like:
self.sizes(): [135361, 256]; self.strides(): [256, 1]
index.sizes(): [477263, 256]; index.strides(): [1, 0]
src.sizes(): [477263, 256]; src.strides(): [256, 1]
So we pick rows from 477K indices and update the dst rows in self. Ideally we want to parallelize on the outer dimension (477K or 135K), but the scatter pattern indicates that writes conflict among threads. The current ATen kernel chooses to parallelize on the inner dimension of 256, which is not performant for the PyG usage: a) per-thread memory access is non-contiguous; b) it cannot be vectorized.
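In code, that input pattern corresponds to something like the following (sizes taken from the profile above; the 2-D index is just a broadcasted 1-D edge destination index):

```python
import torch

num_nodes, num_edges, hidden = 135_361, 477_263, 256
out = torch.zeros(num_nodes, hidden)              # self: [135361, 256], strides (256, 1)
src = torch.randn(num_edges, hidden)              # src:  [477263, 256], strides (256, 1)
edge_dst = torch.randint(num_nodes, (num_edges,))

# index: [477263, 256] with strides (1, 0) -- each row repeats a single node id.
index = edge_dst.view(-1, 1).expand(-1, hidden)

out.scatter_add_(0, index, src)                   # row-wise aggregation into `out`
```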
Algorithm
There exist a couple of algorithms to fix the write conflict, such as: a) sorting; b) segment mutex; c) atomics; d) shared buffer, ... I chose a) sorting, based on the input shape range, which should be the most performant. If anyone comes up with a better idea, please let me know :)
So,
- step-1: convert the index to CSR format, using parallelized radix sort
- step-2: do a normal spmm reduction
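A Python-level emulation of those two steps (the actual PR implements them in C++ with a parallel radix sort; toy sizes, illustration only):

```python
import torch

# Toy sizes; the real inputs look like src: [477263, 256], index broadcast from [477263].
num_nodes, num_edges, hidden = 1_000, 5_000, 16
src = torch.randn(num_edges, hidden)
index = torch.randint(num_nodes, (num_edges,))

# step-1: sort the destination indices -> CSR-style rowptr + permutation.
sorted_index, perm = index.sort()
rowptr = torch.zeros(num_nodes + 1, dtype=torch.long)
rowptr[1:] = torch.bincount(sorted_index, minlength=num_nodes).cumsum(0)

# step-2: segment reduction -- every output row is a contiguous slice of the
# permuted src, so rows can be processed independently (parallel + vectorizable).
out = torch.zeros(num_nodes, hidden)
src_sorted = src[perm]
for r in range(num_nodes):
    start, end = rowptr[r].item(), rowptr[r + 1].item()
    if end > start:
        out[r] = src_sorted[start:end].sum(0)

# Equivalent to:
# torch.zeros(num_nodes, hidden).scatter_add_(0, index.view(-1, 1).expand(-1, hidden), src)
```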
Result
I used the SAGE+reddit from https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py
(The reason I picked this one is that it uses NeighborLoader, which means I don't have to take data loader optimization into account here.)
For inference on a single-socket ICX Xeon (20 cores @ 2.50GHz), end-to-end inference time is reduced from 77.135s to 44.622s.
Attaching part of the profiling logs; as we can see, scatter_add is reduced from 37.797s to 6.454s.
- before
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::scatter_add_ 49.00% 37.797s 49.00% 37.797s 41.445ms 912
aten::index_select 19.74% 15.223s 19.74% 15.227s 6.678ms 2280
aten::linear 0.01% 5.706ms 15.04% 11.602s 12.721ms 912
aten::addmm 6.62% 5.108s 7.92% 6.112s 13.403ms 456
aten::matmul 0.00% 2.339ms 7.10% 5.475s 12.006ms 456
aten::mm 7.09% 5.472s 7.09% 5.472s 12.001ms 456
aten::index 5.89% 4.544s 5.90% 4.549s 9.845ms 462
aten::fill_ 3.59% 2.768s 3.59% 2.768s 2.014ms 1374
aten::zeros 0.01% 7.616ms 3.44% 2.652s 1.936ms 1370
aten::zero_ 0.00% 2.728ms 3.42% 2.636s 1.924ms 1370
- after
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::index_select 32.26% 14.395s 32.27% 14.398s 6.315ms 2280
aten::linear 0.01% 6.329ms 26.19% 11.688s 12.815ms 912
aten::scatter_add_ 14.46% 6.454s 14.46% 6.454s 7.077ms 912
aten::addmm 11.71% 5.223s 13.58% 6.060s 13.289ms 456
aten::matmul 0.01% 2.257ms 12.58% 5.612s 12.307ms 456
aten::mm 12.57% 5.610s 12.57% 5.610s 12.302ms 456
aten::index 9.98% 4.453s 9.99% 4.456s 9.646ms 462
aten::fill_ 5.62% 2.506s 5.62% 2.506s 1.369ms 1830
aten::zeros 0.02% 7.091ms 5.53% 2.466s 1.800ms 1370
aten::zero_ 0.01% 2.886ms 5.50% 2.453s 1.790ms 1370
aten::true_divide_ 0.02% 7.360ms 4.72% 2.106s 4.618ms 456
There are still some TODOs to follow up that will bring additional performance improvements:
- [ ] try cpu memory allocator
- [ ] add logic of can_use_32bit_index
- [ ] do blocking on the nnz dimension
- [ ] add optimization for when index and src are 1d, aka. inner_size is 1
- [ ] extend the current optimization to other reduction types, e.g. max, min, mean
- [ ] (TBD) extend the current optimization to gather, which will be used for training
Just to understand: Does this mean that we first sort index and then do a segment reduction? In that case it might be good to preserve the information that index is sorted such that we do not have this overhead in consecutive GNN layers.
> Just to understand: Does this mean that we first sort index and then do a segment reduction? In that case it might be good to preserve the information that index is sorted such that we do not have this overhead in consecutive GNN layers.
Yes, that's the idea. The overhead is not only the sorting; we also have to calculate the row_ptr indices...
So the index should be constant, since it is an attribute of the dataset?
If we can cache the sorted index, scatter add performance could be further improved by roughly 1/3.
As far as I understand, this would refer to a segment_add implementation, correct? Similar to the one present in torch-scatter.segment. Is there also a chance we can optimize scatter_add without relying on sorting?
Yes, the current scatter_add is kind of like sorting + segment_add, and both parts are properly paralleled.
Because of the semantics limitation, we can not skip sorting since from PyTorch side there is no guarantee that index is in ascending order.
As for the optimization techniques for scatter_add, I tried a couple of methods: a) sorting (the currently submitted PR uses this approach); b) mutex on a block of the write addresses; c) atomics on the innermost dimension; and so on. My experiments show that a) performs best right now, since sorting also helps to increase cache locality and we can enable blocking on the nnz dimension (so there is only one write for each row and multiple reads for src). For b) we cannot do blocking on the writes, so there would be multiple writes as well; c) is only suitable for some inner_size values, e.g. src: [135K, 1] and index: [477K, 1], since atomics come at a higher price than normal FMA.
The overhead is actually not mainly from the sorting itself but from memory allocation; I will switch to the c10 cpu allocator to see if it helps.
Aside from that, we probably have a chance to cache the sorted index so as to save sorting for consecutive layers. Maybe add an attribute called edge_index_sorted: for the 1st layer we fill it with the sorted index, and in the consecutive layers we can directly use segment_add. (Of course we also need to make sure that segment_add is fully optimized.)
This is only a rough idea at the moment. My point is that we first clear the performance bottlenecks from torch (so we optimize scatter_add as it is, with no upper-level API change) and then seek more optimization opportunities from the pyg/torch-scatter/torch-sparse side, where we can make more aggressive optimizations.
Got it, thanks for clarifying!
Update on spmm optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/83727.
This ports the spmm reduction from torch-sparse to torch; the current PR is only for demonstrating performance gains, and the API definition needs more amendment.
For now only sum is added; more will come in the future (max, mean, min), the algorithm being pretty much the same.
The benchmark is selected from ./ogb/examples/nodeproppred/products/gnn.py, since originally this one spent the majority of its time in torch_sparse::spmm_sum. The spmm got roughly a 5x speedup on my 20-core machine.
- before
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
torch_sparse::spmm_sum 97.09% 56.086s 97.09% 56.088s 6.232s 9
aten::linear 0.00% 85.000us 1.38% 795.485ms 88.387ms 9
aten::matmul 0.00% 57.000us 1.38% 795.260ms 88.362ms 9
aten::mm 1.38% 795.201ms 1.38% 795.203ms 88.356ms 9
aten::relu 0.00% 50.000us 0.76% 440.434ms 73.406ms 6
aten::clamp_min 0.76% 440.384ms 0.76% 440.384ms 73.397ms 6
aten::add_ 0.57% 327.801ms 0.57% 327.801ms 36.422ms 9
aten::log_softmax 0.00% 23.000us 0.10% 55.503ms 18.501ms 3
aten::_log_softmax 0.10% 55.480ms 0.10% 55.480ms 18.493ms 3
aten::argmax 0.09% 53.149ms 0.09% 53.153ms 13.288ms 4
aten::index 0.01% 5.771ms 0.01% 5.839ms 324.389us 18
aten::empty 0.00% 1.088ms 0.00% 1.088ms 77.714us 14
- after
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::spmm_sum 87.35% 11.826s 87.36% 11.827s 1.314s 9
aten::linear 0.00% 92.000us 5.87% 794.451ms 88.272ms 9
aten::matmul 0.00% 62.000us 5.87% 794.208ms 88.245ms 9
aten::mm 5.87% 794.143ms 5.87% 794.146ms 88.238ms 9
aten::relu 0.00% 53.000us 3.35% 452.977ms 75.496ms 6
aten::clamp_min 3.35% 452.924ms 3.35% 452.924ms 75.487ms 6
aten::add_ 2.58% 348.663ms 2.58% 348.663ms 38.740ms 9
aten::argmax 0.42% 57.473ms 0.42% 57.475ms 14.369ms 4
aten::log_softmax 0.00% 22.000us 0.39% 52.605ms 17.535ms 3
aten::_log_softmax 0.39% 52.583ms 0.39% 52.583ms 17.528ms 3
aten::index 0.04% 5.100ms 0.04% 5.174ms 287.444us 18
aten::empty 0.01% 1.097ms 0.01% 1.097ms 78.357us 14
To break down the optimization scheme a little bit:
- original (spmm): 56.086s
- naive vectorization: 29.314s
- unroll by 4: 25.664s
- rowwise blocking x16: 21.953s
- balanced thread partition: 11.826s
The balanced thread partition targets balancing the thread payload. Basically, if we directly parallelize on the row dimension, the per-thread payload looks like this (I collected the number of edges handled by each thread):
### thread: 0; min: 1; max: 17482; avg = 172.599
### thread: 1; min: 1; max: 9918; avg = 137.251
### thread: 2; min: 1; max: 5786; avg = 39.7606
### thread: 3; min: 1; max: 4062; avg = 40.0852
### thread: 4; min: 1; max: 10406; avg = 39.7207
### thread: 5; min: 1; max: 3491; avg = 40.0985
### thread: 6; min: 1; max: 5965; avg = 40.0117
### thread: 7; min: 1; max: 5865; avg = 40.3841
### thread: 8; min: 1; max: 5892; avg = 39.969
### thread: 9; min: 1; max: 6076; avg = 39.9995
### thread: 10; min: 1; max: 5215; avg = 40.0757
### thread: 11; min: 1; max: 3893; avg = 40.1075
### thread: 12; min: 1; max: 8052; avg = 39.8108
### thread: 13; min: 1; max: 4062; avg = 39.7186
### thread: 14; min: 1; max: 3243; avg = 40.3022
### thread: 15; min: 1; max: 5008; avg = 40.4213
### thread: 16; min: 1; max: 7657; avg = 40.0987
### thread: 17; min: 1; max: 6784; avg = 40.0618
### thread: 18; min: 1; max: 4810; avg = 39.8836
### thread: 19; min: 1; max: 6429; avg = 39.9829
We can see that the first 2 threads have more payload than the others, so we need to balance the thread payload here. Normally we could use dynamic scheduling in OpenMP, but this won't fit into PyTorch's at::parallel_for, which is essentially static scheduling, so I did manual partitioning here (the logic may be further refined later); see the sketch below.
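For illustration, one simple way to approximate that manual partitioning is to split rows by cumulative nnz rather than by row count (Python sketch only; the kernel does the equivalent inside at::parallel_for):

```python
import torch

def balanced_row_partition(rowptr: torch.Tensor, num_threads: int):
    """Split rows into `num_threads` contiguous chunks with roughly equal nnz."""
    nnz = int(rowptr[-1])
    target = (nnz + num_threads - 1) // num_threads
    chunks, start = [], 0
    for t in range(1, num_threads):
        # first row whose cumulative nnz reaches t * target
        end = int(torch.searchsorted(rowptr, t * target))
        chunks.append((start, end))
        start = end
    chunks.append((start, rowptr.numel() - 1))
    return chunks

# Toy CSR row pointer: 8 rows with a skewed nnz distribution.
rowptr = torch.tensor([0, 100, 110, 115, 120, 200, 205, 300, 420])
print(balanced_row_partition(rowptr, num_threads=4))
# Compare with the naive static split of 2 rows per thread, which leaves the payload unbalanced.
```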
Update on Training optimization
Optimize torch.gather for the classic PyG use case (the index tensor is broadcasted); this will be the backward of scatter in training: https://github.com/pytorch/pytorch/pull/87586
When the index tensor is broadcasted along the last dimension, we can parallelize on the outer dimension and vectorize on the inner dimension, similar to torch.index_select. Compared to scatter, this one is much easier to write.
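A small sketch of the pattern being optimized: when the gather index is broadcast along the last dimension, the result is exactly an index_select over rows, which is what makes the parallelize-outer / vectorize-inner scheme applicable.

```python
import torch

num_nodes, num_edges, hidden = 1_000, 5_000, 64
src = torch.randn(num_nodes, hidden)
edge_idx = torch.randint(num_nodes, (num_edges,))

# Broadcasted index: [num_edges, hidden] with stride 0 on the last dimension.
index = edge_idx.view(-1, 1).expand(-1, hidden)

out_gather = torch.gather(src, 0, index)
out_select = src.index_select(0, edge_idx)
assert torch.equal(out_gather, out_select)
```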
GCN-Reddit (before): [profiling screenshot]
GCN-Reddit (after): [profiling screenshot]
Sort out the 2nd stage optimization a little bit:
- [x] optimization of sampled_addmm on SparseCSR: https://github.com/pytorch/pytorch/pull/90978
- [x] enabling of sampled_addmm on SparseCOO (canceled)
- [ ] unify ReduceTypes: GNN relies on a few operators that have similar ReduceTypes, such as ScatterReduce, SegmentReduce, SampledReduce, SpmmReduce
- [ ] optimization of segment_reduce with lengths and offsets
- [ ] migrate sampled_reduce
- [ ] enable multi aggregation
- [ ] new ReduceType of std (use a stable algorithm, e.g. Welford?)
sampled_addmm COO: this operator writes to a COO tensor, so we need to make sure it is coalesced in order to parallelize it, and coalesce() is very slow right now. I measured on ogbn-products:
- coalesce COO takes 7.2s (and sampled_addmm not included)
- sampled_addmm CSR takes 0.87s
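For reference, a minimal sketch of the CSR variant measured above (toy sizes; torch.sparse.sampled_addmm computes mat1 @ mat2 only at the stored positions of the CSR input, assuming a PyTorch build where the CPU kernel is available):

```python
import torch

# Sparsity pattern of a tiny 2x3 CSR matrix (e.g. the sampled edge set).
crow = torch.tensor([0, 2, 3])
col = torch.tensor([0, 2, 1])
val = torch.zeros(3)
pattern = torch.sparse_csr_tensor(crow, col, val, size=(2, 3))

mat1 = torch.randn(2, 8)
mat2 = torch.randn(8, 3)

# out[i, j] = beta * pattern[i, j] + alpha * (mat1 @ mat2)[i, j], only at stored (i, j).
out = torch.sparse.sampled_addmm(pattern, mat1, mat2)
print(out.values())
```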