HLO Canonicalizations Todo list
To track which of these we consider worth doing, are doing, or still need to do.
cc @ivanradanov @ftynse
- [x] iota reshape (becomes single iota)
%195 = stablehlo.iota dim = 0 : tensor<1024xi32>
%196 = stablehlo.reshape %195 : (tensor<1024xi32>) -> tensor<1x1x1024xi32>
- [x] reshape of pad (becomes diff pad)
%175 = stablehlo.pad %174, %148, low = [0, 0, 1024, 0, 0], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x3x1024x1x1xf32>, tensor<f32>) -> tensor<1x3x2048x1x1xf32>
%176 = stablehlo.reshape %175 : (tensor<1x3x2048x1x1xf32>) -> tensor<1x3x2048xf32>
- [x] mul of pad with 0 (becomes pad of mul) https://github.com/EnzymeAD/Enzyme-JAX/commit/44026d4ab0b6353bc13869311b4536c9ff6c4b75
%175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%177 = stablehlo.multiply %176, %112 : tensor<1x3x2048xf32>
- [x] broadcast of pad (becomes pad of broadcast)
%175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%189 = stablehlo.broadcast_in_dim %177, dims = [0, 2, 4] : (tensor<1x3x2048xf32>) -> tensor<1x1x3x1024x2048xf32>
> broadcast of pad (becomes pad of broadcast)

Is this beneficial? Pad will operate on a larger buffer as a result.
> iota reshape (becomes single iota)
> reshape of pad (becomes diff pad)

I expect this to have little practical effect: reshape should be a metadata operation, I'm not sure it even affects the generated code.
Yeah, I think the issue right now is that there's a bunch of these unnecessarily in the way of other ops (which hopefully do have practical impact when canonicalized together), and hopefully these small things will enable them (while in isolation they're indeed probably not perf-inducing individually).
I decided to start writing up the immediate ones I saw in the MLIR file instead of the bigger ones to do down the line (like batching dots), since I had mini IRs for them.
I started doing pad/mul as that is clearly beneficial by making mul smaller
In the code I saw, the actual pattern was mul(reshape(pad(...))), which I split out into the reshape(pad) and mul(reshape). So I think we'll also need to get that to work end-to-end (I can start it, though, assuming you haven't).
Ironically you can see I was lazy and didn't even rename the result ops (e.g. reshape(pad) was 175/176 and mul(pad) was 175/177 [in the real code it was using the reshape result])
> Is this beneficial? Pad will operate on a larger buffer as a result.

I hypothesize so. In principle a pad of a broadcast should be fusable (e.g. a pad is just doing out[idx] = in bounds ? in[idx2] : pad, and the in[idx2] being a broadcast may get simplified). I suppose the reverse is also the case, but if they don't get fused, doing the memset(x, 0) for the bigger buffer at once seems wise. Also, as above, moving the pads out may help downstream opts.
- [ ] Reshape transpose (I'm iffy on this one's utility)
- [ ] broadcast iota
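A hypothetical sketch (made-up shapes and SSA names) of the case where the broadcast only adds a new leading dimension:
%0 = stablehlo.iota dim = 0 : tensor<16xi32>
%1 = stablehlo.broadcast_in_dim %0, dims = [1] : (tensor<16xi32>) -> tensor<4x16xi32>
// could become a single op:
%1 = stablehlo.iota dim = 1 : tensor<4x16xi32>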
- [ ] generalize reshape of concat
actually jk the outer one is hard
%1384 = stablehlo.concatenate %1382, %1383, dim = 1 : (tensor<1x2x48x4xbf16>, tensor<1x1x48x4xbf16>) -> tensor<1x3x48x4xbf16>
%1385 = stablehlo.reshape %1384 : (tensor<1x3x48x4xbf16>) -> tensor<1x144x4xbf16>
- [x] Reduce sum of reshape -> reduce of the unreshaped operand
%1819 = stablehlo.reshape %1818 : (tensor<56xf32>) -> tensor<7x8xf32>
%1913 = stablehlo.multiply %1819, %1819 : tensor<7x8xf32>
%1914 = stablehlo.reduce(%1913 init: %147) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
- [x] Full Reduce of concat -> concat of reduces
- [x] Full reduce of transposes -> reduce of operands
- [x] Reduce sum of convert -> move the convert inside the reduce [jk this is possibly not representable]
- [ ] Reduce of batched dot (aka matmul) -> dot of reduced operands
%1205 = stablehlo.dot_general %1204, %778, contracting_dims = [0, 1] x [0, 1], precision = [DEFAULT, DEFAULT] : (tensor<46x123x56xbf16>, tensor<46x123x32xbf16>) -> tensor<56x32xbf16>
%1914 = stablehlo.reduce(%1205 init: %147) applies stablehlo.add across dimensions = [0, 1] : (tensor<56x32xf32>, tensor<f32>) -> tensor<f32>
- [ ] Sum reduce of add -> add of sum reduce
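A hypothetical sketch (made-up shapes and SSA names; assumes the init %zero is 0 so it isn't double-counted):
%2 = stablehlo.add %0, %1 : tensor<7x8xf32>
%3 = stablehlo.reduce(%2 init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
// could become
%4 = stablehlo.reduce(%0 init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
%5 = stablehlo.reduce(%1 init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
%3 = stablehlo.add %4, %5 : tensor<f32>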
- [ ] Sum reduce of pad -> sum reduce of the padded operand + (number of inserted values) * pad value [easier if the pad value is 0, which may be all we do]
- [x] Did the zero version of this
- [x] convert of pad -> pad of convert [esp for constants]
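Illustrative sketch (made-up shapes and names); when %cst is a constant the converted pad value folds away:
%pad = stablehlo.pad %x, %cst, low = [1, 0], high = [0, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor<f32>) -> tensor<4x4xf32>
%cvt = stablehlo.convert %pad : (tensor<4x4xf32>) -> tensor<4x4xbf16>
// could become
%xc = stablehlo.convert %x : (tensor<3x4xf32>) -> tensor<3x4xbf16>
%cc = stablehlo.convert %cst : (tensor<f32>) -> tensor<bf16>
%cvt = stablehlo.pad %xc, %cc, low = [1, 0], high = [0, 0], interior = [0, 0] : (tensor<3x4xbf16>, tensor<bf16>) -> tensor<4x4xbf16>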
- [ ] pad of constants (which may be different, but should do the small size check)
- [ ] negate(mul(broadcast(x), y) ) -> mul(broadcast(negate(x)), y))
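Hypothetical sketch (made-up shapes); the negate moves onto the small pre-broadcast tensor:
%b = stablehlo.broadcast_in_dim %x, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
%m = stablehlo.multiply %b, %y : tensor<4x8xf32>
%n = stablehlo.negate %m : tensor<4x8xf32>
// could become
%nx = stablehlo.negate %x : tensor<4xf32>
%nb = stablehlo.broadcast_in_dim %nx, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
%n = stablehlo.multiply %nb, %y : tensor<4x8xf32>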
- [ ] negate(divide(constant, b)) -> divide(constant2, b) [if assuming no infinite values]
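Hypothetical sketch (made-up constant), folding the negation into the constant numerator:
%c = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
%d = stablehlo.divide %c, %b : tensor<4xf32>
%n = stablehlo.negate %d : tensor<4xf32>
// could become
%c2 = stablehlo.constant dense<-2.000000e+00> : tensor<4xf32>
%n = stablehlo.divide %c2, %b : tensor<4xf32>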
- [x] pad of same values [also taking into acct negative vs positive zero]
%147 = stablehlo.constant dense<0.000000e+00>
%81 = stablehlo.constant dense<-0.000000e+00>
%1185 = stablehlo.pad %81, %147, low = [...], high = [...], interior = [...]
- [x] dot(pad(zero from up to i and j, axis=contract, x), y) -> dot(x, y[i:j])
- [ ] mul(bcast x, mul(bcast y, z)) -> mul(z, bcast(mul x, y))
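Hypothetical sketch (made-up shapes); the two broadcasts collapse into one and the x*y multiply happens at the small size:
%bx = stablehlo.broadcast_in_dim %x, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
%by = stablehlo.broadcast_in_dim %y, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
%m0 = stablehlo.multiply %by, %z : tensor<4x8xf32>
%m1 = stablehlo.multiply %bx, %m0 : tensor<4x8xf32>
// could become
%xy = stablehlo.multiply %x, %y : tensor<4xf32>
%b = stablehlo.broadcast_in_dim %xy, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
%m1 = stablehlo.multiply %z, %b : tensor<4x8xf32>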
- [ ] distributive property: full reduce add of mul(z, bcast y) -> mul(full reduce add z, y)
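Hypothetical sketch (made-up shapes; assumes y is a scalar and the init %zero is 0):
%b = stablehlo.broadcast_in_dim %y, dims = [] : (tensor<f32>) -> tensor<7x8xf32>
%m = stablehlo.multiply %z, %b : tensor<7x8xf32>
%r = stablehlo.reduce(%m init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
// could become
%rz = stablehlo.reduce(%z init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
%r = stablehlo.multiply %rz, %y : tensor<f32>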
- [x] slice of broadcast -> broadcast
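Illustrative sketch (made-up shapes), for the case where only the broadcasted dimension is sliced (otherwise the source needs slicing too):
%b = stablehlo.broadcast_in_dim %x, dims = [1] : (tensor<16xf32>) -> tensor<8x16xf32>
%s = stablehlo.slice %b [0:4, 0:16] : (tensor<8x16xf32>) -> tensor<4x16xf32>
// could become
%s = stablehlo.broadcast_in_dim %x, dims = [1] : (tensor<16xf32>) -> tensor<4x16xf32>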
- [x] slice of transpose -> transpose of slice
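Illustrative sketch (made-up shapes); the slice bounds get permuted through the transpose:
%t = stablehlo.transpose %x, dims = [1, 0] : (tensor<8x16xf32>) -> tensor<16x8xf32>
%s = stablehlo.slice %t [0:4, 0:8] : (tensor<16x8xf32>) -> tensor<4x8xf32>
// could become
%s0 = stablehlo.slice %x [0:8, 0:4] : (tensor<8x16xf32>) -> tensor<8x4xf32>
%s = stablehlo.transpose %s0, dims = [1, 0] : (tensor<8x4xf32>) -> tensor<4x8xf32>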
- [ ] slice of convert -> convert of slice. (perhaps generalize to any unary op, if only user)
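Illustrative sketch (made-up shapes); the convert then touches fewer elements:
%c = stablehlo.convert %x : (tensor<8x16xf32>) -> tensor<8x16xbf16>
%s = stablehlo.slice %c [0:4, 0:16] : (tensor<8x16xbf16>) -> tensor<4x16xbf16>
// could become
%s0 = stablehlo.slice %x [0:4, 0:16] : (tensor<8x16xf32>) -> tensor<4x16xf32>
%s = stablehlo.convert %s0 : (tensor<4x16xf32>) -> tensor<4x16xbf16>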
- [x] slice of reshape -> reshape of slice (by @ftynse in https://github.com/EnzymeAD/Enzyme-JAX/pull/60)
- [x] transpose of dot general A, B -> dot general B, A [where applicable]. Should also handle a convert in the middle
- [x] slice of binop(a, b) [if slice is only user] -> binop(slice(a), slice(b)) (#68 https://github.com/EnzymeAD/Enzyme-JAX/commit/7475dd8fb542a3971111c97378b5f69d59fb4631)
- [ ] (partial) sum reduce of broadcast
- [x] convert of convert
- [x] generalize transpose transpose to transpose convert transpose
- [x] dot general transpose(A), B or A, transpose(B) -> dot general A, B (where applicable)
> (partial) sum reduce of broadcast

Do you have an example of this? There are a bunch of cases that are already supported: https://github.com/EnzymeAD/Enzyme-JAX/blob/main/test/lit_tests/broadcastreduce.mlir
- [ ] Select of Pad
> > (partial) sum reduce of broadcast
>
> Do you have an example of this? There are a bunch of cases that are already supported: https://github.com/EnzymeAD/Enzyme-JAX/blob/main/test/lit_tests/broadcastreduce.mlir

Unfortunately not presently, but I'll post when it comes up again.
- [x] Pad of Pad (#67 )
- [x] slice of dot general
%1205 = stablehlo.dot_general %1181, %482, batching_dims = [0, 2] x [0, 1], contracting_dims = [3, 1] x [2, 3], precision = [DEFAULT, DEFAULT] : (tensor<1x16x1x20x100xbf16>, tensor<1x1x20x16x123xbf16>) -> tensor<1x1x100x123xbf16>
%1208 = stablehlo.slice %1205 [0:1, 0:1, 75:100, 0:123] : (tensor<1x1x100x123xbf16>) -> tensor<1x1x25x123xbf16>
- [ ] reshape(concat(constants or reshapes...)) -> concat(constants or reshapes...)
Generalize pad propagations to work with an interstitial reshape (see the PadReshapePad sketch after this list):
- [x] PadPad (need PadReshapePad)
- [ ] BinopPadtoConcat (need BinopReshapePadtoConcat)
- [ ] ConcatPad (need ConcatReshapePad)
- [ ] ReducePad
- [ ] BroadcastPad
- [x] MulZeroPad
- [x] DivZeroPad
- [x] BinopConstPad
- [ ] BinopBinopPadPad
- [ ] AddPadPadtoConcat
- [ ] UnaryPadPush
- [ ] TransposePad
- [x] PadDotGeneral
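For the interstitial-reshape generalization above, a hypothetical PadReshapePad sketch (made-up shapes; assumes both pads use the same %zero value and the reshape only appends a unit dimension):
%p0 = stablehlo.pad %x, %zero, low = [2, 0], high = [0, 0], interior = [0, 0] : (tensor<6x4xf32>, tensor<f32>) -> tensor<8x4xf32>
%r = stablehlo.reshape %p0 : (tensor<8x4xf32>) -> tensor<8x4x1xf32>
%p1 = stablehlo.pad %r, %zero, low = [1, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<8x4x1xf32>, tensor<f32>) -> tensor<9x4x1xf32>
// could become a single pad pushed through the reshape
%r = stablehlo.reshape %x : (tensor<6x4xf32>) -> tensor<6x4x1xf32>
%p1 = stablehlo.pad %r, %zero, low = [3, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<6x4x1xf32>, tensor<f32>) -> tensor<9x4x1xf32>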