[Bug] Half-precision SpMM produces incorrect results
🐛 Bug
SpMM on fp16 features produces results that differ drastically from the fp32 results.
To Reproduce
I applied SpMM to the same feature in both fp32 and fp16:
import dgl
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset
arxiv = DglNodePropPredDataset(name="ogbn-arxiv")
g = arxiv[0][0].int().to(0)
feat_fp32 = th.rand((g.num_src_nodes(), 32)).to(0)
feat_fp16 = feat_fp32.half()
res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)
res_fp16 = dgl.ops.copy_u_sum(g, feat_fp16)
print(res_fp32[1353], res_fp16[1353])
and there is a huge discrepancy between the outputs for the two data types:
tensor([6608.9536, 6597.6616, 6586.2480, 6575.7998, 6614.4141, 6581.5703,
6586.3008, 6545.8574, 6618.2349, 6502.8057, 6572.4966, 6600.5298,
6564.5327, 6536.4312, 6594.0298, 6596.3115, 6627.1670, 6544.6133,
6512.9751, 6500.8130, 6557.1406, 6510.1670, 6592.9810, 6569.1846,
6563.0474, 6555.4189, 6544.5234, 6580.5405, 6584.4556, 6578.9912,
6587.9751, 6617.9995], device='cuda:0') tensor([2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048.,
2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048.,
2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048., 2048.,
2048., 2048.], device='cuda:0', dtype=torch.float16)
Expected behavior
Since the inputs are the same up to fp16 rounding, the two outputs should be close.
Environment
- DGL Version (e.g., 1.0): 0.9.0
- Backend Library & Version (e.g., PyTorch 0.4.1, MXNet/Gluon 1.3): PyTorch 1.12
- OS (e.g., Linux): Fedora 36
- How you installed DGL (conda, pip, source): pip
- Python version: 3.10.4
- CUDA/cuDNN version (if applicable): 11.6
- GPU models and configuration (e.g. V100): RTX 3070
I can confirm the problem still exists after upgrading DGL to v0.9.1.
@chang-l @TristonC Could you help take a look at this?
The fp16 precision interval (the spacing between representable values) between 2048 and 4096 is 2: https://en.wikipedia.org/wiki/Half-precision_floating-point_format.
Since the elements of feat_fp16 are between 0 and 1, they are smaller than half of this spacing, so they get rounded away once the accumulator reaches 2048.
For example,
In [1]: import torch
In [2]: a = torch.tensor([2048.], dtype=torch.float16)
In [3]: b = torch.tensor([0.9], dtype=torch.float16)
In [4]: a + b
Out[4]: tensor([2048.], dtype=torch.float16)
In [5]: c = torch.tensor([1.9], dtype=torch.float16)
In [6]: a + c
Out[6]: tensor([2050.], dtype=torch.float16)
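This also explains why the reported values get stuck at exactly 2048. Here is a small sketch (plain PyTorch, independent of the DGL kernel) of what purely sequential fp16 accumulation does with increments below 1:
import torch
# Repeatedly add ~0.9 to an fp16 accumulator. Once the running sum reaches
# 2048, each 0.9 increment is rounded away (the fp16 spacing there is 2),
# so the sum never grows past 2048 -- matching the 2048 values in the output.
acc = torch.zeros(1, dtype=torch.float16)
inc = torch.tensor([0.9], dtype=torch.float16)
for _ in range(10000):
    acc = acc + inc
print(acc)          # tensor([2048.], dtype=torch.float16)
print(10000 * 0.9)  # 9000.0, the exact sum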
Thanks for taking the time to investigate this issue; now I understand where the difference comes from. One possible approach to mitigate the round-off error is to accumulate in segments instead of element by element:
$$\sum_{i=0}^{N/32 - 1} \left(\sum_{j=0}^{31} x_{32i + j}\right)$$
where each inner sum $\sum_{j=0}^{31} x_{32i + j}$ tends to be greater than $2$. However, one can always construct counterexamples that make this trick fail (see the sketch below).
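Here is a rough numerical sketch of the idea (plain PyTorch, the helper seq_sum is made up for illustration and is not DGL API): summing 8192 values of ~0.9 element by element in fp16 stalls at 2048, while summing segments of 32 first and then adding the 256 partial sums stays much closer to the fp32 reference, though still not exact:
import torch

def seq_sum(v):
    # element-by-element accumulation in the tensor's own dtype (fp16 here)
    acc = torch.zeros((), dtype=v.dtype)
    for e in v:
        acc = acc + e
    return acc

x = torch.full((8192,), 0.9, dtype=torch.float16)
direct = seq_sum(x)                                        # stalls at 2048
partial = torch.stack([seq_sum(s) for s in x.split(32)])   # 256 partial sums of ~28.8
segmented = seq_sum(partial)                               # much closer to the fp32 reference (~7373)
print(direct, segmented, x.float().sum())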
Another (more robust) fix is to use a float32 accumulator inside the kernel and cast the result back to fp16 after the aggregation finishes, when the output type is fp16.
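In the meantime, a user-side workaround is to do the aggregation in fp32 and cast back afterwards (a sketch reusing g, feat_fp16, and res_fp32 from the reproduction above):
# Upcast the feature to fp32, aggregate, then cast the result back to fp16.
res_fp16_safe = dgl.ops.copy_u_sum(g, feat_fp16.float()).half()
print(res_fp32[1353], res_fp16_safe[1353])  # should now agree up to fp16 rounding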