Disable offloading attention-like kernels to rocMLIR if there is no wmma/mfma for the requested dtype
DOR (Definition of Ready)
Ready
Description
rocMLIR doesn't support attention-like kernels (e.g. GEG, Attention, CEG) if there is no associated wmma/mfma instruction for that dtype on the hardware.
For example, using FP32 on navi3x for attention would crash.
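A minimal sketch of the kind of gate this asks for, using hypothetical names (`has_matrix_instruction`, `offload_attention_to_mlir`, the simplified `dtype` enum) rather than MIGraphX's actual interfaces; the arch checks reflect that CDNA MFMA has an FP32 variant while RDNA3/RDNA4 WMMA does not:

```cpp
// Hypothetical sketch only; MIGraphX's real gate would live in fuse_mlir and
// use its own target/dtype queries.
#include <string>

enum class dtype { fp32, fp16, bf16 };

// True when the target has a matrix instruction (MFMA on CDNA, WMMA on
// RDNA3/RDNA4) for the requested dtype.
bool has_matrix_instruction(const std::string& gfx_arch, dtype t)
{
    const bool is_cdna = gfx_arch.rfind("gfx90", 0) == 0 or gfx_arch.rfind("gfx94", 0) == 0;
    const bool is_rdna = gfx_arch.rfind("gfx11", 0) == 0 or gfx_arch.rfind("gfx12", 0) == 0;
    if(is_cdna)
        return true; // MFMA covers fp32/fp16/bf16
    if(is_rdna)
        return t == dtype::fp16 or t == dtype::bf16; // WMMA has no fp32 variant
    return false;
}

// Offload attention-like kernels to rocMLIR only when supported; otherwise
// fall back to the regular MIGraphX lowering instead of crashing.
bool offload_attention_to_mlir(const std::string& gfx_arch, dtype t)
{
    return has_matrix_instruction(gfx_arch, t);
}
```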
DOD (Definition of Done)
Attention-like kernels are not generated for unsupported dtypes.
Right now there isn't an easy way to do this. If we skip attention in `fuse_mlir`, there will be a leftover `group` op that won't get lowered. We probably need to add an `inline_group` pass to inline the attention afterwards, but then we will need to run the `fuse_pointwise_reduce` pass again. This will need a tweak so `fuse_pointwise` uses a different name to avoid name conflicts on the second run; probably move the counter to the pass manager (see the sketch below).
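A rough sketch of the "move the counter to the pass manager" idea, with invented names (`pass_manager_state`, `unique_name`) that are assumptions, not MIGraphX's API; the point is that repeated `fuse_pointwise` runs draw from one shared counter, so generated module names never collide:

```cpp
#include <cstddef>
#include <string>

// Hypothetical stand-in for state owned by the pass manager rather than by an
// individual pass instance.
struct pass_manager_state
{
    std::size_t counter = 0;

    // Every pass run asks here for names, so a second fuse_pointwise run
    // continues at e.g. "pointwise7" instead of restarting at "pointwise0".
    std::string unique_name(const std::string& prefix)
    {
        return prefix + std::to_string(counter++);
    }
};
```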
So these would be the set of subtasks to get this to work:

- [ ] Add an `inline_group` pass to inline group operators (see the sketch after this list)
- [ ] Tweak `fuse_pointwise` to avoid name conflicts
- [ ] Add another `fuse_pointwise_reduce` pass
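A self-contained sketch of the first subtask, using toy stand-in types (`instruction`, `module`) instead of MIGraphX's real IR, just to show the shape an `inline_group` pass could take: splice each `group` op's body into the parent module and drop the wrapper.

```cpp
#include <list>
#include <memory>
#include <string>
#include <vector>

// Toy IR; MIGraphX's instruction/module types are assumed, not reproduced.
struct instruction
{
    std::string op;                               // e.g. "group", "dot", "softmax"
    std::vector<instruction*> inputs;             // operands in the parent module
    std::shared_ptr<std::list<instruction>> body; // submodule, set only for "group"
};

using module = std::list<instruction>;

void inline_group(module& m)
{
    for(auto it = m.begin(); it != m.end();)
    {
        if(it->op != "group" or it->body == nullptr)
        {
            ++it;
            continue;
        }
        // Move the group's body into the parent right before the group op.
        // A real pass would also map submodule parameters to it->inputs and
        // replace uses of the group's output with the inlined return value.
        m.splice(it, *it->body);
        it = m.erase(it);
    }
}
```

After this runs, the module is flat again, so `fuse_pointwise_reduce` can be rerun over it.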
Any progress on this one? It is causing failures on the rocMLIR CI now:
```sh
root@18441c05b42d:~/repo/AMDMIGraphX/build(develop)# MIGRAPHX_MLIR_USE_SPECIFIC_OPS="convolution,fused,dot,attention" MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1 MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1 MIGRAPHX_MLIR_ENABLE_SPLITK=1 MIGRAPHX_ENABLE_EXTRA_MLIR=1 MIGRAPHX_DISABLE_LAYERNORM_FUSION=1 MIGRAPHX_ENABLE_SPLIT_REDUCE=1 ./bin/migraphx-driver compile --onnx ~/bert_base_cased_1.onnx --fill1 input_ids --input-dim @input_ids 1 384 |
```
Running the above causes MIGraphX to generate an F32 attention kernel, which fails to compile on Navi4x.