Implement GLU
cognaiger9 opened this issue 1 year ago
- Add the GLU operation with forward and backward kernels for contiguous tensors.
- Add a driver and gtests for the kernels.
- MIOpen performs better than ROCm when:
  - the input and output tensors are contiguous,
  - the split dimension is 0, and
  - the input tensor has fewer than 400,000 elements (forward) or fewer than 800,000 elements (backward).
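As a CPU reference for what the forward and backward kernels compute, here is a minimal Python sketch. It assumes the usual GLU definition: the input is split in half along the split dimension into `a` and `b`, the output is `a * sigmoid(b)`, and the gradients are `dL/da = g * sigmoid(b)` and `dL/db = g * a * sigmoid(b) * (1 - sigmoid(b))`. On a contiguous tensor split along dim 0, `a` and `b` are simply the first and second halves of the flat buffer, which is one plausible reason this case is the fast path. All function names are illustrative, and `miopen_glu_preferred` is a hypothetical restatement of the three conditions above, not MIOpen's solver code.

```python
import math
from typing import List, Sequence

# Element-count limits quoted in this issue; used here as illustrative
# cut-overs, not as MIOpen's actual solver logic (hypothetical).
FWD_MAX_NUMEL = 400_000
BWD_MAX_NUMEL = 800_000


def sigmoid(v: float) -> float:
    # Numerically stable logistic function (avoids exp overflow).
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def packed_strides(shape: Sequence[int]) -> List[int]:
    """Row-major strides of a contiguous (packed) tensor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def _half(x: Sequence[float], shape: Sequence[int]) -> int:
    if len(x) != math.prod(shape) or shape[0] % 2 != 0:
        raise ValueError("x must hold prod(shape) values and shape[0] must be even")
    # With a dim-0 split on packed data, a = x[:half] and b = x[half:].
    return len(x) // 2


def glu_fwd_dim0(x: Sequence[float], shape: Sequence[int]) -> List[float]:
    """out = a * sigmoid(b); the output has shape[0] // 2 rows."""
    half = _half(x, shape)
    return [x[i] * sigmoid(x[half + i]) for i in range(half)]


def glu_bwd_dim0(x: Sequence[float], shape: Sequence[int],
                 grad_out: Sequence[float]) -> List[float]:
    """Returns dL/dx laid out like x: [dL/da..., dL/db...]."""
    half = _half(x, shape)
    if len(grad_out) != half:
        raise ValueError("grad_out must have half as many elements as x")
    grad_a, grad_b = [], []
    for i in range(half):
        a, s = x[i], sigmoid(x[half + i])
        grad_a.append(grad_out[i] * s)                  # d(out)/da = sigmoid(b)
        grad_b.append(grad_out[i] * a * s * (1.0 - s))  # d(out)/db = a*s*(1-s)
    return grad_a + grad_b


def miopen_glu_preferred(shape: Sequence[int], strides: Sequence[int],
                         dim: int, backward: bool) -> bool:
    """Hypothetical check mirroring the three conditions listed above."""
    limit = BWD_MAX_NUMEL if backward else FWD_MAX_NUMEL
    return (list(strides) == packed_strides(shape)
            and dim == 0
            and math.prod(shape) < limit)
```

For example, `[128 64 7 7]` has 128 · 64 · 7 · 7 = 401,408 elements, so under these thresholds it falls just outside the forward limit but inside the backward one.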
Average improvement over ROCm

| dtype | fwd | bwd |
|---|---|---|
| float16 | 1.3 | 1.56 |
| float32 | 1.11 | 1.53 |
| bfloat16 | 1.27 | 1.54 |
Detailed benchmark
float16

| op_name | dtype | size | dim | direction | ROCm | MIOpen | Improvement (ROCm / MIOpen) | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| GLU | float16 | [2 320 4 4 4] | 0 | fwd | 6816 | 4889 | 1.39 | 1.33 |
| GLU | float16 | [2 320 4 4 4] | 0 | bwd | 7968 | 5066 | 1.57 | 1.55 |
| GLU | float16 | [32 64 3 3 3] | 0 | fwd | 6448 | 4889 | 1.32 | 1.26 |
| GLU | float16 | [32 64 3 3 3] | 0 | bwd | 7776 | 4836 | 1.61 | 1.62 |
| GLU | float16 | [64 3 11 11] | 0 | fwd | 6384 | 4747 | 1.34 | 1.32 |
| GLU | float16 | [64 3 11 11] | 0 | bwd | 7904 | 4853 | 1.63 | 1.58 |
| GLU | float16 | [256 256 1 1] | 0 | fwd | 6176 | 5031 | 1.23 | 1.26 |
| GLU | float16 | [256 256 1 1] | 0 | bwd | 8048 | 5066 | 1.59 | 1.54 |
| GLU | float16 | [128 64 7 7] | 0 | fwd | 7488 | 5777 | 1.30 | 1.16 |
| GLU | float16 | [128 64 7 7] | 0 | bwd | 9344 | 6382 | 1.46 | 1.42 |
| GLU | float16 | [64 64 7 7] | 0 | fwd | 6880 | 5600 | 1.23 | 1.31 |
| GLU | float16 | [64 64 7 7] | 0 | bwd | 8608 | 5778 | 1.49 | 1.48 |
| GLU | float16 | [64 32 7 7] | 0 | fwd | 6192 | 4995 | 1.24 | 1.30 |
| GLU | float16 | [64 32 7 7] | 0 | bwd | 7904 | 5262 | 1.50 | 1.56 |
| GLU | float16 | [32 32 7 7] | 0 | fwd | 6368 | 4907 | 1.30 | 1.26 |
| GLU | float16 | [32 32 7 7] | 0 | bwd | 7936 | 4995 | 1.59 | 1.59 |
float32

| op_name | dtype | size | dim | direction | ROCm | MIOpen | Improvement (ROCm / MIOpen) | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| GLU | float32 | [2 320 4 4 4] | 0 | fwd | 5936 | 5013 | 1.18 | 1.33 |
| GLU | float32 | [2 320 4 4 4] | 0 | bwd | 7744 | 5013 | 1.54 | 1.55 |
| GLU | float32 | [32 64 3 3 3] | 0 | fwd | 5408 | 4942 | 1.09 | 1.26 |
| GLU | float32 | [32 64 3 3 3] | 0 | bwd | 7968 | 5262 | 1.51 | 1.62 |
| GLU | float32 | [64 3 11 11] | 0 | fwd | 5376 | 4960 | 1.08 | 1.32 |
| GLU | float32 | [64 3 11 11] | 0 | bwd | 7904 | 4889 | 1.62 | 1.58 |
| GLU | float32 | [256 256 1 1] | 0 | fwd | 5680 | 5191 | 1.09 | 1.26 |
| GLU | float32 | [256 256 1 1] | 0 | bwd | 8064 | 5386 | 1.50 | 1.54 |
| GLU | float32 | [128 64 7 7] | 0 | fwd | 7056 | 6524 | 1.08 | 1.16 |
| GLU | float32 | [128 64 7 7] | 0 | bwd | 10064 | 7182 | 1.40 | 1.42 |
| GLU | float32 | [64 64 7 7] | 0 | fwd | 6128 | 5635 | 1.09 | 1.31 |
| GLU | float32 | [64 64 7 7] | 0 | bwd | 8944 | 5902 | 1.52 | 1.48 |
| GLU | float32 | [64 32 7 7] | 0 | fwd | 5856 | 5155 | 1.14 | 1.30 |
| GLU | float32 | [64 32 7 7] | 0 | bwd | 8320 | 5351 | 1.55 | 1.56 |
| GLU | float32 | [32 32 7 7] | 0 | fwd | 5472 | 4942 | 1.11 | 1.26 |
| GLU | float32 | [32 32 7 7] | 0 | bwd | 8112 | 5031 | 1.61 | 1.59 |
bfloat16

| op_name | dtype | size | dim | direction | ROCm | MIOpen | Improvement (ROCm / MIOpen) | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| GLU | bfloat16 | [2 320 4 4 4] | 0 | fwd | 6928 | 5226 | 1.33 | 1.33 |
| GLU | bfloat16 | [2 320 4 4 4] | 0 | bwd | 7776 | 5013 | 1.55 | 1.55 |
| GLU | bfloat16 | [32 64 3 3 3] | 0 | fwd | 6320 | 5031 | 1.26 | 1.26 |
| GLU | bfloat16 | [32 64 3 3 3] | 0 | bwd | 7904 | 4889 | 1.62 | 1.62 |
| GLU | bfloat16 | [64 3 11 11] | 0 | fwd | 6416 | 4871 | 1.32 | 1.32 |
| GLU | bfloat16 | [64 3 11 11] | 0 | bwd | 7792 | 4942 | 1.58 | 1.58 |
| GLU | bfloat16 | [256 256 1 1] | 0 | fwd | 6240 | 4960 | 1.26 | 1.26 |
| GLU | bfloat16 | [256 256 1 1] | 0 | bwd | 8016 | 5191 | 1.54 | 1.54 |
| GLU | bfloat16 | [128 64 7 7] | 0 | fwd | 7392 | 6365 | 1.16 | 1.16 |
| GLU | bfloat16 | [128 64 7 7] | 0 | bwd | 9344 | 6560 | 1.42 | 1.42 |
| GLU | bfloat16 | [64 64 7 7] | 0 | fwd | 7024 | 5369 | 1.31 | 1.31 |
| GLU | bfloat16 | [64 64 7 7] | 0 | bwd | 8576 | 5778 | 1.48 | 1.48 |
| GLU | bfloat16 | [64 32 7 7] | 0 | fwd | 6560 | 5031 | 1.30 | 1.30 |
| GLU | bfloat16 | [64 32 7 7] | 0 | bwd | 8096 | 5191 | 1.56 | 1.56 |
| GLU | bfloat16 | [32 32 7 7] | 0 | fwd | 6384 | 5049 | 1.26 | 1.26 |
| GLU | bfloat16 | [32 32 7 7] | 0 | bwd | 7968 | 5013 | 1.59 | 1.59 |
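The per-row Improvement values appear to be the ROCm time divided by the MIOpen time (e.g. 6816 / 4889 ≈ 1.39), and the "Average improvement over ROCm" figures appear to be the arithmetic mean of the eight per-shape ratios for each dtype and direction. The sketch below recomputes those averages from the (ROCm, MIOpen) pairs copied from the tables above; the issue does not state the timing unit, but only the ratio matters here. The helper names are illustrative.

```python
from statistics import mean

# (ROCm, MIOpen) timing pairs copied from the detailed tables, in row order.
TIMINGS = {
    ("float16", "fwd"): [(6816, 4889), (6448, 4889), (6384, 4747), (6176, 5031),
                         (7488, 5777), (6880, 5600), (6192, 4995), (6368, 4907)],
    ("float16", "bwd"): [(7968, 5066), (7776, 4836), (7904, 4853), (8048, 5066),
                         (9344, 6382), (8608, 5778), (7904, 5262), (7936, 4995)],
    ("float32", "fwd"): [(5936, 5013), (5408, 4942), (5376, 4960), (5680, 5191),
                         (7056, 6524), (6128, 5635), (5856, 5155), (5472, 4942)],
    ("float32", "bwd"): [(7744, 5013), (7968, 5262), (7904, 4889), (8064, 5386),
                         (10064, 7182), (8944, 5902), (8320, 5351), (8112, 5031)],
    ("bfloat16", "fwd"): [(6928, 5226), (6320, 5031), (6416, 4871), (6240, 4960),
                          (7392, 6365), (7024, 5369), (6560, 5031), (6384, 5049)],
    ("bfloat16", "bwd"): [(7776, 5013), (7904, 4889), (7792, 4942), (8016, 5191),
                          (9344, 6560), (8576, 5778), (8096, 5191), (7968, 5013)],
}


def improvement(rocm: int, miopen: int) -> float:
    """Per-row ratio: ROCm time divided by MIOpen time (higher is better)."""
    return rocm / miopen


def average_improvement() -> dict:
    """Arithmetic mean of the per-shape ratios for each (dtype, direction)."""
    return {key: mean(improvement(r, m) for r, m in rows)
            for key, rows in TIMINGS.items()}


if __name__ == "__main__":
    for (dtype, direction), avg in average_improvement().items():
        print(f"{dtype:9s} {direction}: {avg:.2f}")
```

The recomputed means agree with the summary table to within 0.01.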