Implement MaskedFill
cognaiger9 opened this issue 10 months ago
- Add MaskedFill operation with forward and backward kernels.
- Add a driver and gtests for the kernels.
- MIOpen is faster than ROCm when:
  - Forward: not all tensors are contiguous.
  - Backward: not all tensors are contiguous and the output has fewer than 65536 elements.
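
For readers new to the operation, here is a minimal CPU reference sketch. It assumes PyTorch-style `masked_fill` semantics (the output takes `value` where the mask is set and the input elsewhere; the input gradient is zeroed at masked positions) and a `uint8_t` mask where nonzero means "fill". The `View` struct, the function names, and `IsImprovementOverROCm` are hypothetical illustrations, not MIOpen's API. The last function only restates the two conditions listed above. Non-contiguous tensors are handled by mapping each logical element index through per-dimension strides. Only the gradient with respect to the input is shown.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical strided view: per-dimension lengths and strides, both in elements.
struct View {
    std::vector<std::size_t> lengths;
    std::vector<std::size_t> strides;

    std::size_t numel() const {
        std::size_t n = 1;
        for (std::size_t len : lengths) n *= len;
        return n;
    }

    // Packed row-major layout check; size-1 dimensions may have any stride.
    bool contiguous() const {
        std::size_t expected = 1;
        for (std::size_t d = lengths.size(); d-- > 0;) {
            if (lengths[d] != 1 && strides[d] != expected) return false;
            expected *= lengths[d];
        }
        return true;
    }

    // Map a logical row-major element index to a memory offset via the strides.
    std::size_t offset(std::size_t idx) const {
        std::size_t off = 0;
        for (std::size_t d = lengths.size(); d-- > 0;) {
            off += (idx % lengths[d]) * strides[d];
            idx /= lengths[d];
        }
        return off;
    }
};

// Forward: out = in, except out = value wherever the mask is nonzero.
template <typename T>
void MaskedFillForward(const T* in, const View& in_v,
                       const std::uint8_t* mask, const View& mask_v,
                       T* out, const View& out_v, T value) {
    assert(in_v.numel() == out_v.numel() && mask_v.numel() == out_v.numel());
    for (std::size_t i = 0; i < out_v.numel(); ++i)
        out[out_v.offset(i)] = mask[mask_v.offset(i)] ? value : in[in_v.offset(i)];
}

// Backward (input gradient only): pass the output gradient through where the
// mask is zero; masked positions were overwritten by `value`, so they get zero.
template <typename T>
void MaskedFillBackward(const T* out_grad, const View& og_v,
                        const std::uint8_t* mask, const View& mask_v,
                        T* in_grad, const View& ig_v) {
    assert(og_v.numel() == ig_v.numel() && mask_v.numel() == ig_v.numel());
    for (std::size_t i = 0; i < ig_v.numel(); ++i)
        in_grad[ig_v.offset(i)] = mask[mask_v.offset(i)] ? T(0) : out_grad[og_v.offset(i)];
}

// Hypothetical predicate restating when MIOpen is reported to be ahead of ROCm.
bool IsImprovementOverROCm(bool is_forward, bool all_contiguous, std::size_t output_numel) {
    if (all_contiguous) return false;           // both directions need a non-contiguous tensor
    return is_forward || output_numel < 65536;  // backward additionally needs a small output
}
```

A GPU version would typically assign one logical index per thread and perform the same offset computation; a fully contiguous case could skip the stride arithmetic entirely.
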

Average improvement over ROCm (Improvement = ROCm time / MIOpen time; values above 1 mean MIOpen is faster):

| dtype | fwd | bwd |
|---|---|---|
| float16 | 1.29 | 1.29 |
| float32 | 1.32 | 1.27 |
| bfloat16 | 1.27 | 1.29 |
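
The summary figures appear to be the arithmetic mean of the per-case `Improvement` column of the detail tables. The sketch below assumes that aggregation (the issue does not state it) and recomputes the float16 forward average from the eleven rows of the float16 (forward) detail table; the function names are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-case improvement: ROCm time divided by MIOpen time (> 1 means MIOpen is faster).
double Improvement(double rocm_time, double miopen_time) {
    return rocm_time / miopen_time;
}

// Assumed aggregation: arithmetic mean of the per-case ratios.
double AverageImprovement(const std::vector<double>& rocm, const std::vector<double>& miopen) {
    assert(rocm.size() == miopen.size() && !rocm.empty());
    double sum = 0.0;
    for (std::size_t i = 0; i < rocm.size(); ++i) sum += Improvement(rocm[i], miopen[i]);
    return sum / static_cast<double>(rocm.size());
}

// ROCm and MIOpen columns of the float16 (forward) detail table; evaluates to about 1.29.
double Float16ForwardAverage() {
    const std::vector<double> rocm = {18032, 14640, 12976, 14416, 12528, 12352,
                                      15248, 1069495, 14215882, 2785305, 7190869};
    const std::vector<double> miopen = {12906, 12230, 12248, 10933, 11608, 10488,
                                        10737, 820598, 9130800, 1857410, 6010170};
    return AverageImprovement(rocm, miopen);
}
```

Averaging the ratios (rather than dividing total ROCm time by total MIOpen time) weights every shape equally, so the small shapes count as much as the large ones.
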

Detail Benchmark

float16 (forward)

| op_name | dtype | input_size | value | contiguous | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| MaskedFill | float16 | [10 10 10] | 0.5 | noncont | fwd | 18032 | 12906 | 1.40 |
| MaskedFill | float16 | [10 10 20] | 0.5 | noncont | fwd | 14640 | 12230 | 1.20 |
| MaskedFill | float16 | [5 10 25] | 0.5 | noncont | fwd | 12976 | 12248 | 1.06 |
| MaskedFill | float16 | [20 20 20] | 0.5 | noncont | fwd | 14416 | 10933 | 1.32 |
| MaskedFill | float16 | [40 40 40] | 0.5 | noncont | fwd | 12528 | 11608 | 1.08 |
| MaskedFill | float16 | [16 16 64] | 0.5 | noncont | fwd | 12352 | 10488 | 1.18 |
| MaskedFill | float16 | [32 32 32] | 0.5 | noncont | fwd | 15248 | 10737 | 1.42 |
| MaskedFill | float16 | [256 256 256] | 0.5 | noncont | fwd | 1069495 | 820598 | 1.30 |
| MaskedFill | float16 | [512 512 512] | 0.5 | noncont | fwd | 14215882 | 9130800 | 1.56 |
| MaskedFill | float16 | [256 256 512] | 0.5 | noncont | fwd | 2785305 | 1857410 | 1.50 |
| MaskedFill | float16 | [256 512 512] | 0.5 | noncont | fwd | 7190869 | 6010170 | 1.20 |

float32 (forward)

| op_name | dtype | input_size | value | contiguous | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| MaskedFill | float32 | [10 10 20] | 0.5 | noncont | fwd | 14576 | 12266 | 1.19 |
| MaskedFill | float32 | [5 10 25] | 0.5 | noncont | fwd | 13072 | 12248 | 1.07 |
| MaskedFill | float32 | [20 20 20] | 0.5 | noncont | fwd | 13248 | 10844 | 1.22 |
| MaskedFill | float32 | [40 40 40] | 0.5 | noncont | fwd | 11920 | 11608 | 1.03 |
| MaskedFill | float32 | [16 16 64] | 0.5 | noncont | fwd | 10720 | 10560 | 1.02 |
| MaskedFill | float32 | [32 32 32] | 0.5 | noncont | fwd | 11088 | 10880 | 1.02 |
| MaskedFill | float32 | [64 64 64] | 0.5 | noncont | fwd | 21248 | 19164 | 1.11 |
| MaskedFill | float32 | [256 256 256] | 0.5 | noncont | fwd | 1290997 | 823229 | 1.57 |
| MaskedFill | float32 | [512 512 512] | 0.5 | noncont | fwd | 16238058 | 8743070 | 1.86 |
| MaskedFill | float32 | [256 256 512] | 0.5 | noncont | fwd | 3525363 | 2004449 | 1.76 |
| MaskedFill | float32 | [256 512 512] | 0.5 | noncont | fwd | 8162845 | 4808530 | 1.70 |

bfloat16 (forward)

| op_name | dtype | input_size | value | contiguous | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| MaskedFill | bfloat16 | [10 10 10] | 0.5 | noncont | fwd | 17184 | 13351 | 1.29 |
| MaskedFill | bfloat16 | [10 10 20] | 0.5 | noncont | fwd | 14880 | 12871 | 1.16 |
| MaskedFill | bfloat16 | [20 20 20] | 0.5 | noncont | fwd | 13984 | 11271 | 1.24 |
| MaskedFill | bfloat16 | [40 40 40] | 0.5 | noncont | fwd | 12112 | 11430 | 1.06 |
| MaskedFill | bfloat16 | [16 16 64] | 0.5 | noncont | fwd | 13248 | 10702 | 1.24 |
| MaskedFill | bfloat16 | [32 32 32] | 0.5 | noncont | fwd | 14816 | 11022 | 1.34 |
| MaskedFill | bfloat16 | [256 256 256] | 0.5 | noncont | fwd | 938136 | 824153 | 1.14 |
| MaskedFill | bfloat16 | [512 512 512] | 0.5 | noncont | fwd | 14202954 | 9128760 | 1.56 |
| MaskedFill | bfloat16 | [256 256 512] | 0.5 | noncont | fwd | 2798921 | 1859950 | 1.50 |
| MaskedFill | bfloat16 | [256 512 512] | 0.5 | noncont | fwd | 7154437 | 6038530 | 1.18 |

float16 (backward)

| op_name | dtype | input_size | value | contiguous | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| MaskedFill | float16 | [10 10 10] | 0.5 | noncont | bwd | 19136 | 13902 | 1.38 |
| MaskedFill | float16 | [10 10 20] | 0.5 | noncont | bwd | 17056 | 13368 | 1.28 |
| MaskedFill | float16 | [5 10 25] | 0.5 | noncont | bwd | 16176 | 13048 | 1.24 |
| MaskedFill | float16 | [20 20 20] | 0.5 | noncont | bwd | 16224 | 11235 | 1.44 |
| MaskedFill | float16 | [40 40 40] | 0.5 | noncont | bwd | 15696 | 12248 | 1.28 |
| MaskedFill | float16 | [16 16 64] | 0.5 | noncont | bwd | 14880 | 11040 | 1.35 |
| MaskedFill | float16 | [32 32 32] | 0.5 | noncont | bwd | 16432 | 11644 | 1.41 |
| MaskedFill | float16 | [64 64 64] | 0.5 | noncont | bwd | 20592 | 20355 | 1.01 |

float32 (backward)

| op_name | dtype | input_size | value | contiguous | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| MaskedFill | float32 | [10 10 20] | 0.5 | noncont | bwd | 17072 | 12959 | 1.32 |
| MaskedFill | float32 | [5 10 25] | 0.5 | noncont | bwd | 16768 | 12888 | 1.30 |
| MaskedFill | float32 | [20 20 20] | 0.5 | noncont | bwd | 15584 | 11235 | 1.39 |
| MaskedFill | float32 | [40 40 40] | 0.5 | noncont | bwd | 14800 | 12711 | 1.16 |
| MaskedFill | float32 | [16 16 64] | 0.5 | noncont | bwd | 15168 | 10897 | 1.39 |
| MaskedFill | float32 | [32 32 32] | 0.5 | noncont | bwd | 14896 | 11608 | 1.28 |
| MaskedFill | float32 | [64 64 64] | 0.5 | noncont | bwd | 24464 | 22719 | 1.08 |

bfloat16 (backward)

| op_name | dtype | input_size | value | contiguous | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| MaskedFill | bfloat16 | [10 10 10] | 0.5 | noncont | bwd | 18880 | 13475 | 1.40 |
| MaskedFill | bfloat16 | [10 10 20] | 0.5 | noncont | bwd | 17424 | 12782 | 1.36 |
| MaskedFill | bfloat16 | [5 10 25] | 0.5 | noncont | bwd | 15888 | 12728 | 1.25 |
| MaskedFill | bfloat16 | [20 20 20] | 0.5 | noncont | bwd | 16544 | 11182 | 1.48 |
| MaskedFill | bfloat16 | [40 40 40] | 0.5 | noncont | bwd | 14784 | 12177 | 1.21 |
| MaskedFill | bfloat16 | [16 16 64] | 0.5 | noncont | bwd | 15008 | 11164 | 1.34 |
| MaskedFill | bfloat16 | [32 32 32] | 0.5 | noncont | bwd | 15232 | 11839 | 1.29 |
| MaskedFill | bfloat16 | [64 64 64] | 0.5 | noncont | bwd | 21184 | 20391 | 1.04 |