Implement Softmax
cognaiger9 opened this issue 10 months ago
- Add Softmax operation with forward and backward kernels.
- Add driver and gtest for kernels.
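A CPU reference for the forward and backward passes could look like the minimal sketch below. It is not the kernel code from this PR. It assumes a contiguous row-major tensor stored as a flat Python list, and the helper names (`softmax_fwd`, `softmax_bwd`, `_split`, `_rows`) are hypothetical. The forward pass subtracts the slice maximum before exponentiating. The backward pass uses the standard softmax gradient dx = y · (dy − Σₖ dyₖ·yₖ), computed independently for every slice along `dim`.

```python
import math
from typing import Iterator, List, Sequence, Tuple


def _split(shape: Sequence[int], dim: int) -> Tuple[int, int, int]:
    """Return (outer, n, inner) for a contiguous row-major tensor reduced along `dim`."""
    return math.prod(shape[:dim]), shape[dim], math.prod(shape[dim + 1:])


def _rows(shape: Sequence[int], dim: int) -> Iterator[List[int]]:
    """Yield, for every independent softmax slice, the flat indices of its n elements."""
    outer, n, inner = _split(shape, dim)
    for o in range(outer):
        for i in range(inner):
            base = o * n * inner + i
            yield [base + k * inner for k in range(n)]


def softmax_fwd(x: List[float], shape: Sequence[int], dim: int) -> List[float]:
    """y = exp(x - max) / sum(exp(x - max)) along `dim` (hypothetical reference)."""
    y = [0.0] * len(x)
    for idx in _rows(shape, dim):
        m = max(x[j] for j in idx)            # subtract the max for numerical stability
        e = [math.exp(x[j] - m) for j in idx]
        s = sum(e)
        for j, ej in zip(idx, e):
            y[j] = ej / s
    return y


def softmax_bwd(y: List[float], dy: List[float], shape: Sequence[int], dim: int) -> List[float]:
    """dx = y * (dy - sum_k dy_k * y_k) along `dim` (hypothetical reference)."""
    dx = [0.0] * len(y)
    for idx in _rows(shape, dim):
        dot = sum(dy[j] * y[j] for j in idx)  # inner product of dy and y over the slice
        for j in idx:
            dx[j] = y[j] * (dy[j] - dot)
    return dx


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 1.0, 1.0, 1.0]        # a (2, 3) tensor, row-major
    print(softmax_fwd(x, (2, 3), 1))          # softmax over each row
    print(softmax_fwd(x, (2, 3), 0))          # softmax over each column
```

A gtest-style check would compare the GPU output against a reference like this one within a per-dtype tolerance; the finite-difference comparison is a common way to validate the backward formula itself.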
Average improvement over ROCm
| type | fwd | bwd |
|---|---|---|
| float16 | 1.23 | 1.34 |
| float32 | 1.42 | 1.7 |
| bfloat16 | 1.36 | 1.4 |
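Detailed benchmark
Improvement is the ratio ROCm / MIOpen for each configuration; values above 1 mean the MIOpen implementation is faster.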
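Reading `size` as the input shape and `dim` as the axis the softmax runs over (an interpretation of the column names, not stated explicitly here), each configuration splits into `outer × reduce × inner` for a contiguous tensor. The hypothetical helper `softmax_geometry` below computes that split; `inner` is also the element stride between neighboring entries of one softmax, so `dim` 3 reduces over contiguous memory while `dim` 0 reduces with the largest stride.

```python
import math
from typing import Sequence, Tuple


def softmax_geometry(shape: Sequence[int], dim: int) -> Tuple[int, int, int]:
    """Split a contiguous row-major shape into (outer, reduce, inner).

    outer  : number of independent blocks before `dim`
    reduce : length of each softmax (shape[dim])
    inner  : stride, in elements, between consecutive entries of one softmax
    """
    if not 0 <= dim < len(shape):
        raise ValueError(f"dim {dim} out of range for shape {tuple(shape)}")
    outer = math.prod(shape[:dim])
    reduce = shape[dim]
    inner = math.prod(shape[dim + 1:])
    return outer, reduce, inner


if __name__ == "__main__":
    for shape, dim in [((40, 12, 512, 512), 0), ((40, 12, 512, 512), 3), ((64, 21, 254, 333), 1)]:
        outer, reduce, inner = softmax_geometry(shape, dim)
        print(f"{list(shape)} dim={dim}: {outer * inner} softmaxes of length {reduce}, stride {inner}")
```

For example, `[40 12 512 512]` with `dim` 0 gives 3,145,728 softmaxes of length 40 with stride 3,145,728, while the same shape with `dim` 3 gives 245,760 contiguous softmaxes of length 512.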
float16
| op_name | dtype | size | dim | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|
| softmax | float16 | [40 12 512 512] | 0 | fwd | 2387296 | 2068530 | 1.15 |
| softmax | float16 | [40 12 512 512] | 3 | bwd | 1621787 | 1258830 | 1.29 |
| softmax | float16 | [16 21 512 512] | 3 | bwd | 1140924 | 882061 | 1.29 |
| softmax | float16 | [24 21 512 512] | 2 | fwd | 2461811 | 2095470 | 1.17 |
| softmax | float16 | [24 21 512 512] | 2 | bwd | 2809512 | 2517230 | 1.12 |
| softmax | float16 | [24 21 512 512] | 3 | bwd | 1664764 | 1308200 | 1.27 |
| softmax | float16 | [64 21 254 333] | 1 | fwd | 2262283 | 1630830 | 1.39 |
| softmax | float16 | [64 21 254 333] | 1 | bwd | 2415334 | 1312320 | 1.84 |
| softmax | float16 | [64 21 274 275] | 0 | fwd | 1874841 | 1521130 | 1.23 |
| softmax | float16 | [64 21 274 275] | 0 | bwd | 2286461 | 1916420 | 1.19 |
| softmax | float16 | [16 12 512 512] | 2 | fwd | 1093984 | 857349 | 1.28 |
| softmax | float16 | [16 12 512 512] | 2 | bwd | 1345465 | 1014480 | 1.33 |
| softmax | float16 | [16 12 512 512] | 3 | bwd | 698140 | 514815 | 1.36 |
| softmax | float16 | [16 12 1024 1024] | 0 | fwd | 3340209 | 3115000 | 1.07 |
| softmax | float16 | [16 12 1024 1024] | 3 | bwd | 2690629 | 1914170 | 1.41 |
| softmax | float16 | [64 21 273 322] | 1 | fwd | 2308081 | 1705720 | 1.35 |
float32
| op_name | dtype | size | dim | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|
| softmax | float32 | [40 12 512 512] | 0 | bwd | 3387696 | 2342360 | 1.45 |
| softmax | float32 | [40 12 512 512] | 3 | bwd | 2746646 | 1366120 | 2.01 |
| softmax | float32 | [16 21 512 512] | 0 | bwd | 2182362 | 1269290 | 1.72 |
| softmax | float32 | [16 21 512 512] | 3 | bwd | 1845381 | 962271 | 1.92 |
| softmax | float32 | [24 21 512 512] | 2 | fwd | 2377126 | 2025380 | 1.17 |
| softmax | float32 | [24 21 512 512] | 2 | bwd | 3644878 | 2351230 | 1.55 |
| softmax | float32 | [24 21 512 512] | 3 | bwd | 2748875 | 1362150 | 2.02 |
| softmax | float32 | [64 21 254 333] | 1 | fwd | 3117586 | 1786060 | 1.75 |
| softmax | float32 | [64 21 254 333] | 1 | bwd | 3821724 | 2345770 | 1.63 |
| softmax | float32 | [64 21 254 333] | 3 | bwd | 2558036 | 1624640 | 1.57 |
| softmax | float32 | [64 21 274 275] | 0 | fwd | 2797934 | 1945240 | 1.44 |
| softmax | float32 | [64 21 274 275] | 0 | bwd | 3651317 | 2577740 | 1.42 |
| softmax | float32 | [64 21 274 275] | 3 | bwd | 2365995 | 1516970 | 1.56 |
| softmax | float32 | [16 12 512 512] | 2 | fwd | 1007843 | 777760 | 1.30 |
| softmax | float32 | [16 12 512 512] | 2 | bwd | 1470582 | 890450 | 1.65 |
| softmax | float32 | [16 12 512 512] | 3 | bwd | 1099249 | 560503 | 1.96 |
bfloat16
| op_name | dtype | size | dim | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|
| softmax | bfloat16 | [40 12 512 512] | 0 | fwd | 2389249 | 2089470 | 1.14 |
| softmax | bfloat16 | [40 12 512 512] | 3 | fwd | 2938001 | 1847650 | 1.59 |
| softmax | bfloat16 | [40 12 512 512] | 3 | bwd | 1615068 | 1284590 | 1.26 |
| softmax | bfloat16 | [16 21 512 512] | 3 | bwd | 1098381 | 899471 | 1.22 |
| softmax | bfloat16 | [24 21 512 512] | 2 | fwd | 2492450 | 2133930 | 1.17 |
| softmax | bfloat16 | [24 21 512 512] | 3 | bwd | 1625726 | 1334160 | 1.22 |
| softmax | bfloat16 | [64 21 254 333] | 1 | fwd | 2264252 | 1673840 | 1.35 |
| softmax | bfloat16 | [64 21 254 333] | 1 | bwd | 2526948 | 1325550 | 1.91 |
| softmax | bfloat16 | [64 21 274 275] | 0 | fwd | 1870393 | 1531300 | 1.22 |
| softmax | bfloat16 | [64 21 274 275] | 0 | bwd | 2306828 | 1920690 | 1.20 |
| softmax | bfloat16 | [16 12 512 512] | 2 | fwd | 1102672 | 877420 | 1.26 |
| softmax | bfloat16 | [16 12 512 512] | 2 | bwd | 1302154 | 1014289 | 1.28 |
| softmax | bfloat16 | [16 12 512 512] | 3 | bwd | 701804 | 525358 | 1.34 |
| softmax | bfloat16 | [16 12 1024 1024] | 3 | bwd | 2645928 | 1967850 | 1.34 |
| softmax | bfloat16 | [64 21 273 322] | 1 | fwd | 2311698 | 1748050 | 1.32 |
| softmax | bfloat16 | [64 21 273 322] | 1 | bwd | 2661480 | 1470710 | 1.81 |