Introduce MLIR transform dialect to BladeDISC
We'll start exploring the use of the MLIR transform dialect to do codegen for (fused) compute-intensive patterns. The initial target is to support GEMM codegen on the ARM platform, to address the dynamic shape problem of the Arm Compute Library.
The initial plan is:
- [x] Step 1, enhance the fusion decision pass. We'll add a new fusion kind `kTransform` for the transform-based fusion pattern.
- [x] Step 2, lower the lmhlo fusion op to linalg on tensor (see the first sketch after this list).
- [x] Step 3, transform the linalg computation to loops using the transform dialect (see the schedule sketch after this list).
- [x] Step 4, refine the transformed loops to make them suitable for the BladeDISC runtime.
- [x] Step 5, add a new pass to the disc pass pipeline to drive the above process.
- [x] Step 6, weight pre-packing support (see the packing sketch after this list).
  - [x] add a `disc_linalg.multi_level_pack` op, used for doing the packing.
  - [x] add a `transform.disc.cache_read` transform op, relying on the `disc_linalg.multi_level_pack` op.
  - [x] add folding support for `disc_linalg.multi_level_pack`.
  - [x] lower `disc_linalg.multi_level_pack` to loops if it cannot be folded.
  - [x] fuse const weight ops into the `kTransform` fusion pattern, lower them to linalg and then schedule them.
- [x] Step 7, assign a default schedule for each `kTransform` pattern.
- [x] Step 8, schedule selection logic injection.
- [x] Step 9, initial model-level testing: Bert (Albert).
- [x] Step 10, support nt, tn, tt format GEMM (transposed operand layouts).
- [ ] Step 11, support batch matmul.
- [x] Step 12, support GEMM epilogue fusion (see the epilogue sketch after this list).
- [ ] Step 13, performance optimization
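For illustration, a minimal sketch of what step 2 produces: the GEMM expressed with upstream `linalg`/`tensor` ops on tensors with dynamic shapes (function name and element type here are made up; on older MLIR revisions `tensor.empty` is spelled `linalg.init_tensor`):

```mlir
// A plain dynamic-shape GEMM after lowering to linalg on tensor.
func.func @gemm(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %zero = arith.constant 0.0 : f32
  // Runtime dimensions: m comes from %a, n from %b.
  %m = tensor.dim %a, %c0 : tensor<?x?xf32>
  %n = tensor.dim %b, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%m, %n) : tensor<?x?xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %res = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                       outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %res : tensor<?x?xf32>
}
```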
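Steps 3 and 7 drive this IR with a transform-dialect schedule. Below is a minimal sketch in the upstream syntax of the LLVM 15/16 era (`!pdl.operation` handles; exact op spellings vary across MLIR versions). The tile sizes are placeholders, and this is not BladeDISC's actual default schedule, which additionally uses `transform.disc.*` ops such as `transform.disc.cache_read`:

```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
  // Find the GEMM inside the fused pattern.
  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
  // Tile m and n for parallelism/locality, then k for the reduction.
  %tiled, %loops:3 = transform.structured.tile %matmul [288, 48, 512]
  // Vectorize the structured ops remaining in the function body.
  %func = transform.structured.match ops{["func.func"]} in %arg0
  %vectorized = transform.structured.vectorize %func
}
```

Step 8 then injects logic to pick among several such schedules at compile time instead of always applying the single default one.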
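For step 6, `disc_linalg.multi_level_pack` blocks the constant weight into the layout the inner kernel expects, potentially at several nesting levels. Its exact syntax is disc-specific, but a single-level pack can be illustrated with the upstream `tensor.pack` op (available in newer MLIR; shapes here are made up):

```mlir
// Re-layout a 128x256 weight into 8x16 blocks: the result dims are
// [128/8, 256/16, 8, 16], so each inner tile is contiguous in memory.
func.func @pack_weight(%w: tensor<128x256xf32>) -> tensor<16x16x8x16xf32> {
  %dest = tensor.empty() : tensor<16x16x8x16xf32>
  %packed = tensor.pack %w inner_dims_pos = [0, 1] inner_tiles = [8, 16]
      into %dest : tensor<128x256xf32> -> tensor<16x16x8x16xf32>
  return %packed : tensor<16x16x8x16xf32>
}
```

Per the sub-items above, a pack whose source is a compile-time constant is folded (i.e., evaluated once ahead of time, which is the pre-packing itself), and only non-foldable packs are lowered to loops.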
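For step 12, the epilogue (e.g., a bias add following the GEMM) stays in the same `kTransform` pattern as an elementwise op, so the schedule can fuse it into the tiled loop nest (e.g., with something like the upstream `transform.structured.fuse_into_containing_op`). A sketch of such an epilogue as a `linalg.generic`; the function signature is assumed for illustration:

```mlir
// Broadcast %bias along the rows of %res and add it elementwise.
func.func @bias_add(%res: tensor<?x?xf32>, %bias: tensor<?xf32>,
                    %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %biased = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%res, %bias : tensor<?x?xf32>, tensor<?xf32>)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%x: f32, %b: f32, %out: f32):
      // Elementwise bias add; further epilogue ops chain the same way.
      %sum = arith.addf %x, %b : f32
      linalg.yield %sum : f32
  } -> tensor<?x?xf32>
  return %biased : tensor<?x?xf32>
}
```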
End-to-end model tests were run on Bert Base (TF) and Albert (PyTorch), on a g6r instance, using a single thread. Note that we currently have only one default schedule for all shapes, and this schedule is known to be less performant when n or k is large, so the initial numbers below are expected to improve once schedule selection logic is supported.
Bert Base (TF)
| input | TF 2.8 (s) | DISC-ACL (s) | DISC-Transform (s) | speedup (DISC-ACL / DISC-Transform) |
|---|---|---|---|---|
| (1, 128) | 0.742 | 0.638 | 0.656 | 97.3% |
| (2, 128) | 1.41 | 1.24 | 1.27 | 97.6% |
| (4, 128) | 2.85 | 2.36 | 2.55 | 92.5% |
| (8, 128) | 5.84 | 4.68 | 5.07 | 92.3% |
| (16, 128) | 11.9 | 9.55 | 10.2 | 93.6% |
Albert (PyTorch)
| input | TorchScript (s) | OnnxRuntime (s) | DISC-ACL (s) | DISC-Transform (s) |
|---|---|---|---|---|
| (2, 12) | 0.197 | 0.140 | 0.117 | 0.139 |
Sharing doc:
https://bladedisc.oss-cn-hangzhou.aliyuncs.com/docs/transform-dialect-based-codegen-in-bladedisc.pdf