Implement GatherND
Open · cognaiger9 opened this issue 10 months ago · 0 comments
- Details of the operation (TensorFlow)
- Add the GatherND operation with a backward kernel.
- Add a driver and gtest coverage for the kernels.
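For context, the backward pass of GatherND is a scatter-add: each output-gradient slice is accumulated into the parameter-gradient location that its index row selected in the forward pass, so duplicate indices sum. Below is a minimal pure-Python reference sketch of those semantics as TensorFlow defines the operation. It is not the MIOpen kernel; the helper names (`slice_offset`, `gather_nd`, `gather_nd_backward`) and the flat row-major layout are assumptions made for illustration.

```python
from math import prod


def slice_offset(index, shape):
    """Flat row-major offset of the slice selected by a K-element index."""
    offset = 0
    for d, i in enumerate(index):
        offset += i * prod(shape[d + 1:])
    return offset


def gather_nd(params, shape, indices):
    """Forward: concatenate params[index] for every index row.

    params is a flat row-major list; each index row has K <= len(shape)
    entries and selects a slice of prod(shape[K:]) elements.
    """
    k = len(indices[0])
    slice_size = prod(shape[k:])
    out = []
    for index in indices:
        base = slice_offset(index, shape)
        out.extend(params[base:base + slice_size])
    return out


def gather_nd_backward(out_grad, shape, indices):
    """Backward: scatter-add out_grad slices into a zeroed param gradient."""
    k = len(indices[0])
    slice_size = prod(shape[k:])
    param_grad = [0.0] * prod(shape)
    for row, index in enumerate(indices):
        base = slice_offset(index, shape)
        for j in range(slice_size):
            # Duplicate index rows accumulate rather than overwrite.
            param_grad[base + j] += out_grad[row * slice_size + j]
    return param_grad


# Hypothetical toy case: a 2x3 tensor gathered by row, with row 1 hit twice.
shape = (2, 3)
params = [0, 1, 2, 3, 4, 5]
indices = [(1,), (0,), (1,)]
out = gather_nd(params, shape, indices)
grad = gather_nd_backward([1.0] * len(out), shape, indices)
```

In this toy case `indices` has shape [3 1], mirroring the "indices size" column in the benchmark tables, and row 1 of the gradient receives two contributions because it was gathered twice.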
Average improvement over ROCm

| type | bwd |
|---|---|
| float16 | 4.14 |
| float32 | 3.74 |
| bfloat16 | 5.27 |
Detailed Benchmark
float16

| op_name | dtype | param grad size | indices size | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| GatherND | float16 | [1 512 85742] | [512 2] | contiguous | bwd | 3610262 | 1305290 | 2.77 |
| GatherND | float16 | [1 25088 32317] | [25088 2] | contiguous | bwd | 87586133 | 22057100 | 3.97 |
| GatherND | float16 | [1 512 4096] | [512 2] | contiguous | bwd | 1393790 | 1239260 | 1.12 |
| GatherND | float16 | [1 1024 4096] | [1024 2] | contiguous | bwd | 2649580 | 2244160 | 1.18 |
| GatherND | float16 | [1 4096 9192] | [4096 2] | contiguous | bwd | 12533864 | 4564960 | 2.75 |
| GatherND | float16 | [1 9192 18384] | [9192 2] | contiguous | bwd | 24809847 | 7643800 | 3.25 |
| GatherND | float16 | [1 18384 18384] | [18384 2] | contiguous | bwd | 55369070 | 11581500 | 4.78 |
| GatherND | float16 | [2 2 8] | [96 1] | contiguous | bwd | 178046 | 57423 | 3.10 |
| GatherND | float16 | [10 20 3] | [288 2] | contiguous | bwd | 222762 | 53814 | 4.14 |
| GatherND | float16 | [16 128 256] | [16 2] | contiguous | bwd | 244585 | 28889 | 8.47 |
| GatherND | float16 | [256 256 72] | [16 3] | contiguous | bwd | 286021 | 44445 | 6.44 |
| GatherND | float16 | [32 1600 64] | [96 3] | contiguous | bwd | 304995 | 39165 | 7.79 |
float32

| op_name | dtype | param grad size | indices size | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| GatherND | float32 | [1 512 85742] | [512 2] | contiguous | bwd | 2564567 | 894720 | 2.87 |
| GatherND | float32 | [1 25088 32317] | [25088 2] | contiguous | bwd | 60985836 | 18362300 | 3.32 |
| GatherND | float32 | [1 512 1024] | [512 2] | contiguous | bwd | 849864 | 398033 | 2.14 |
| GatherND | float32 | [1 512 4096] | [512 2] | contiguous | bwd | 1041142 | 428007 | 2.43 |
| GatherND | float32 | [1 1024 4096] | [1024 2] | contiguous | bwd | 1900522 | 721967 | 2.63 |
| GatherND | float32 | [1 4096 9192] | [4096 2] | contiguous | bwd | 7941847 | 2002540 | 3.97 |
| GatherND | float32 | [1 9192 18384] | [9192 2] | contiguous | bwd | 16906965 | 6008330 | 2.81 |
| GatherND | float32 | [1 18384 18384] | [18384 2] | contiguous | bwd | 36964247 | 12001700 | 3.08 |
| GatherND | float32 | [2 2 8] | [96 1] | contiguous | bwd | 178383 | 38329 | 4.65 |
| GatherND | float32 | [16 16 3] | [864 1] | contiguous | bwd | 376348 | 168642 | 2.23 |
| GatherND | float32 | [10 20 3] | [288 2] | contiguous | bwd | 210139 | 41102 | 5.11 |
| GatherND | float32 | [16 128 256] | [16 2] | contiguous | bwd | 215755 | 29387 | 7.34 |
| GatherND | float32 | [256 256 72] | [16 3] | contiguous | bwd | 237562 | 44978 | 5.28 |
| GatherND | float32 | [32 1600 64] | [96 3] | contiguous | bwd | 181503 | 39253 | 4.62 |
bfloat16

| op_name | dtype | param grad size | indices size | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| GatherND | bfloat16 | [1 512 85742] | [512 2] | contiguous | bwd | 3703774 | 1218990 | 3.04 |
| GatherND | bfloat16 | [1 25088 32317] | [25088 2] | contiguous | bwd | 88624446 | 19149400 | 4.63 |
| GatherND | bfloat16 | [1 512 4096] | [512 2] | contiguous | bwd | 1435258 | 758537 | 1.89 |
| GatherND | bfloat16 | [1 1024 4096] | [1024 2] | contiguous | bwd | 2698582 | 806822 | 3.34 |
| GatherND | bfloat16 | [1 4096 9192] | [4096 2] | contiguous | bwd | 12357887 | 1230440 | 10.04 |
| GatherND | bfloat16 | [1 9192 18384] | [9192 2] | contiguous | bwd | 26294178 | 4329720 | 6.07 |
| GatherND | bfloat16 | [1 18384 18384] | [18384 2] | contiguous | bwd | 55355429 | 8296139 | 6.67 |
| GatherND | bfloat16 | [2 2 8] | [96 1] | contiguous | bwd | 176559 | 54667 | 3.23 |
| GatherND | bfloat16 | [10 20 3] | [288 2] | contiguous | bwd | 228746 | 54063 | 4.23 |
| GatherND | bfloat16 | [16 128 256] | [16 2] | contiguous | bwd | 236586 | 31502 | 7.51 |
| GatherND | bfloat16 | [256 256 72] | [16 3] | contiguous | bwd | 266646 | 46951 | 5.68 |
| GatherND | bfloat16 | [32 1600 64] | [96 3] | contiguous | bwd | 269126 | 38951 | 6.91 |
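
The "MIOpen vs ROCm" column appears to be the ROCm time divided by the MIOpen time, and the summary averages appear to be the per-dtype mean of that column. The sketch below recomputes both from the timings listed in the three tables, as a way to sanity-check the figures. The names `TIMINGS`, `speedups`, and `averages` are illustrative only, and the timing unit is whatever the tables use (it is not stated in the issue).

```python
from statistics import mean

# (ROCm, MIOpen) timing pairs copied row by row from the tables above.
TIMINGS = {
    "float16": [
        (3610262, 1305290), (87586133, 22057100), (1393790, 1239260),
        (2649580, 2244160), (12533864, 4564960), (24809847, 7643800),
        (55369070, 11581500), (178046, 57423), (222762, 53814),
        (244585, 28889), (286021, 44445), (304995, 39165),
    ],
    "float32": [
        (2564567, 894720), (60985836, 18362300), (849864, 398033),
        (1041142, 428007), (1900522, 721967), (7941847, 2002540),
        (16906965, 6008330), (36964247, 12001700), (178383, 38329),
        (376348, 168642), (210139, 41102), (215755, 29387),
        (237562, 44978), (181503, 39253),
    ],
    "bfloat16": [
        (3703774, 1218990), (88624446, 19149400), (1435258, 758537),
        (2698582, 806822), (12357887, 1230440), (26294178, 4329720),
        (55355429, 8296139), (176559, 54667), (228746, 54063),
        (236586, 31502), (266646, 46951), (269126, 38951),
    ],
}


def speedups(rows):
    """Per-case speedup, assumed to be ROCm time / MIOpen time."""
    return [rocm / miopen for rocm, miopen in rows]


# Arithmetic mean of the per-case speedups for each dtype.
averages = {dtype: mean(speedups(rows)) for dtype, rows in TIMINGS.items()}

for dtype, avg in averages.items():
    print(f"{dtype}: {avg:.3f}")
```

Under that assumption the recomputed means fall within about 0.02 of the reported averages (4.14, 3.74, 5.27); small differences come from rounding or truncating to two decimals.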