HLO Canonicalizations Todo list

Open • wsmoses opened this issue 1 year ago • 31 comments

A running list to mark which canonicalizations we consider worth doing, are doing, or still need to do.

cc @ivanradanov @ftynse

  • [x] iota reshape (becomes single iota; see the sketch after this list)
    %195 = stablehlo.iota dim = 0 : tensor<1024xi32>
    %196 = stablehlo.reshape %195 : (tensor<1024xi32>) -> tensor<1x1x1024xi32>
  • [x] reshape of pad (becomes a different pad; see the sketch after this list)
    %175 = stablehlo.pad %174, %148, low = [0, 0, 1024, 0, 0], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x3x1024x1x1xf32>, tensor<f32>) -> tensor<1x3x2048x1x1xf32>
    %176 = stablehlo.reshape %175 : (tensor<1x3x2048x1x1xf32>) -> tensor<1x3x2048xf32>
    
  • [x] mul of pad with 0 (becomes pad of mul) https://github.com/EnzymeAD/Enzyme-JAX/commit/44026d4ab0b6353bc13869311b4536c9ff6c4b75
    %175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
    %177 = stablehlo.multiply %176, %112 : tensor<1x3x2048xf32>
  • [x] broadcast of pad (becomes pad of broadcast)
    %175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
    %189 = stablehlo.broadcast_in_dim %177, dims = [0, 2, 4] : (tensor<1x3x2048xf32>) -> tensor<1x1x3x1024x2048xf32>
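
As referenced above, a sketch of what the first two rewrites might produce (hypothetical SSA names; shapes taken from the snippets above):

    // iota reshape: the reshape only adds unit dims, so the iota can be emitted
    // directly in the target shape, along the shifted dimension.
    %iota = stablehlo.iota dim = 2 : tensor<1x1x1024xi32>

    // reshape of pad: drop the unit dims first, then pad at the smaller rank.
    %r = stablehlo.reshape %174 : (tensor<1x3x1024x1x1xf32>) -> tensor<1x3x1024xf32>
    %p = stablehlo.pad %r, %148, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>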

wsmoses avatar Mar 14 '24 04:03 wsmoses

broadcast of pad (becomes pad of broadcast)

Is this beneficial? Pad will operate on a larger buffer as a result.

ftynse avatar Mar 14 '24 14:03 ftynse

iota reshape (becomes single iota)
reshape of pad (becomes diff pad)

I expect this to have little practical effect: reshape should be a metadata operation; I'm not sure it even affects the generated code.

ftynse avatar Mar 14 '24 14:03 ftynse

Yeah, I think the issue right now is that there are a bunch of these unnecessarily sitting in the way of other ops (which hopefully do have practical impact when canonicalized together), and hopefully these small rewrites will enable those (even though in isolation they are indeed probably not individually performance-improving).

I decided to start by writing up the immediate ones I saw in the MLIR file rather than the bigger ones to do down the line (like batching dots), since I had mini IRs for them.

wsmoses avatar Mar 14 '24 17:03 wsmoses

I started doing pad/mul, as that one is clearly beneficial: it makes the mul smaller.

ftynse avatar Mar 14 '24 18:03 ftynse

So in the code I saw, the actual pattern was mul(reshape(pad(...))), which I split out into reshape(pad) and mul(reshape). So I think we'll also need to get that to work end to end (I can start it, though, assuming you haven't).

Ironically, you can see I was lazy and didn't even rename the result ops (e.g. reshape(pad) was 175/176 and mul(pad) was 175/177; in the real code it was using the reshape result).

wsmoses avatar Mar 14 '24 20:03 wsmoses

Is this beneficial? Pad will operate on a larger buffer as a result.

I hypothesize so. In principle a pad of a broadcast should be fusable (a pad is just doing out[idx] = in_bounds ? in[idx2] : pad_value, and the in[idx2] being a broadcast may get simplified). I suppose the reverse is also the case, but if they don't get fused, doing the memset(x, 0) for the bigger buffer all at once seems wise. Also, as above, moving the pads outward may help downstream optimizations.
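
To make the direction concrete, a minimal sketch with hypothetical names and shapes (%x, %zero):

    // before: pad the small buffer, then broadcast the padded result
    %p = stablehlo.pad %x, %zero, low = [2], high = [0], interior = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<6xf32>
    %b = stablehlo.broadcast_in_dim %p, dims = [1] : (tensor<6xf32>) -> tensor<3x6xf32>
    // after: broadcast first, then pad the (larger) result once
    %b2 = stablehlo.broadcast_in_dim %x, dims = [1] : (tensor<4xf32>) -> tensor<3x4xf32>
    %p2 = stablehlo.pad %b2, %zero, low = [0, 2], high = [0, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x6xf32>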

wsmoses avatar Mar 14 '24 20:03 wsmoses

  • [ ] Reshape transpose (I'm iffy on this one's utility)
  • [ ] broadcast iota
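
For the broadcast-iota item, a sketch with hypothetical shapes: when an iota is broadcast along newly introduced dimensions, the result is itself an iota.

    %i = stablehlo.iota dim = 0 : tensor<8xi32>
    %b = stablehlo.broadcast_in_dim %i, dims = [1] : (tensor<8xi32>) -> tensor<4x8xi32>
    // becomes a single iota along the mapped dimension
    %b2 = stablehlo.iota dim = 1 : tensor<4x8xi32>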

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [ ] generalize reshape of concat

actually, jk: the outer one is hard

    %1384 = stablehlo.concatenate %1382, %1383, dim = 1 : (tensor<1x2x48x4xbf16>, tensor<1x1x48x4xbf16>) -> tensor<1x3x48x4xbf16>
    %1385 = stablehlo.reshape %1384 : (tensor<1x3x48x4xbf16>) -> tensor<1x144x4xbf16>
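
Using the shapes above, the generalization would presumably reshape each operand and concatenate along the collapsed dimension; a sketch, keeping in mind this only works here because the concat dimension is the outermost of the collapsed group (which is exactly the part that is hard in general):

    %a = stablehlo.reshape %1382 : (tensor<1x2x48x4xbf16>) -> tensor<1x96x4xbf16>
    %b = stablehlo.reshape %1383 : (tensor<1x1x48x4xbf16>) -> tensor<1x48x4xbf16>
    %c = stablehlo.concatenate %a, %b, dim = 1 : (tensor<1x96x4xbf16>, tensor<1x48x4xbf16>) -> tensor<1x144x4xbf16>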

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [x] Reduce sum of reshape: -> reduce of unreshaped
    %1819 = stablehlo.reshape %1818 : (tensor<56xf32>) -> tensor<7x8xf32>
    %1913 = stablehlo.multiply %1819, %1819 : tensor<7x8xf32>
    %1914 = stablehlo.reduce(%1913 init: %147) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
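
A sketch of the rewritten form, using the values above; this also pushes the elementwise multiply through the reshape (mul(reshape(x), reshape(x)) is reshape(mul(x, x))), and a full reduce over all dimensions is insensitive to the reshape:

    %m = stablehlo.multiply %1818, %1818 : tensor<56xf32>
    %r = stablehlo.reduce(%m init: %147) applies stablehlo.add across dimensions = [0] : (tensor<56xf32>, tensor<f32>) -> tensor<f32>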

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [x] Full Reduce of concat -> concat of reduces
  • [x] Full reduce of transposes -> reduce of operands
  • [x] Reduce sum of convert -> move the convert inside the reduce [jk this is possibly not representable]
  • [ ] Reduce of batched dot (aka matmul) -> dot of reduced operands
        %1205 = stablehlo.dot_general %1204, %778, contracting_dims = [0, 1] x [0, 1], precision = [DEFAULT, DEFAULT] : (tensor<46x123x56xbf16>, tensor<46x123x32xbf16>) -> tensor<56x32xbf16>
        %1914 = stablehlo.reduce(%1205 init: %147) applies stablehlo.add across dimensions = [0, 1] : (tensor<56x32xbf16>, tensor<bf16>) -> tensor<bf16>
  • [ ] Sum reduce of add -> add of sum reduce
  • [ ] Sum reduce of pad -> sum reduce of the padded op + (number of inserted values) * pad value [easier if the pad value is 0, which may be all we do]
    • [x] Did the zero version of this (see the sketch below)
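
For the zero version, a minimal sketch with hypothetical names and shapes: zero padding contributes nothing to an add reduction, so the pad can simply be dropped.

    %p = stablehlo.pad %x, %zero, low = [2], high = [3], interior = [0] : (tensor<5xf32>, tensor<f32>) -> tensor<10xf32>
    %r = stablehlo.reduce(%p init: %zero) applies stablehlo.add across dimensions = [0] : (tensor<10xf32>, tensor<f32>) -> tensor<f32>
    // becomes
    %r2 = stablehlo.reduce(%x init: %zero) applies stablehlo.add across dimensions = [0] : (tensor<5xf32>, tensor<f32>) -> tensor<f32>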

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [x] convert of pad -> pad of convert [esp for constants]
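
A sketch with hypothetical names and shapes; this is especially profitable when the pad value is a constant, since the converted pad value folds away:

    %p = stablehlo.pad %x, %c, low = [1], high = [0], interior = [0] : (tensor<4xbf16>, tensor<bf16>) -> tensor<5xbf16>
    %v = stablehlo.convert %p : (tensor<5xbf16>) -> tensor<5xf32>
    // becomes
    %xc = stablehlo.convert %x : (tensor<4xbf16>) -> tensor<4xf32>
    %cc = stablehlo.convert %c : (tensor<bf16>) -> tensor<f32>
    %v2 = stablehlo.pad %xc, %cc, low = [1], high = [0], interior = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<5xf32>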

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [ ] pad of constants (the pad value and the constant may differ, but we should do the small-size check before materializing the folded constant)
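
A sketch of the fold with hypothetical values; the size check matters because the folded constant materializes the padded shape:

    %c = stablehlo.constant dense<1.000000e+00> : tensor<2xf32>
    %z = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %p = stablehlo.pad %c, %z, low = [1], high = [1], interior = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
    // folds to
    %p2 = stablehlo.constant dense<[0.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00]> : tensor<4xf32>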

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [ ] negate(mul(broadcast(x), y) ) -> mul(broadcast(negate(x)), y))
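
A sketch with hypothetical names and shapes: negating the pre-broadcast value touches far fewer elements.

    %bx = stablehlo.broadcast_in_dim %x, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
    %m = stablehlo.multiply %bx, %y : tensor<4x8xf32>
    %n = stablehlo.negate %m : tensor<4x8xf32>
    // becomes
    %nx = stablehlo.negate %x : tensor<4xf32>
    %bnx = stablehlo.broadcast_in_dim %nx, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
    %n2 = stablehlo.multiply %bnx, %y : tensor<4x8xf32>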

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [ ] negate(divide(constant, b)) -> divide(constant2, b) [where constant2 = -constant; assuming no infinite values]
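
A sketch with hypothetical values; the negation folds into the constant numerator:

    %c = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
    %d = stablehlo.divide %c, %b : tensor<4xf32>
    %n = stablehlo.negate %d : tensor<4xf32>
    // becomes
    %c2 = stablehlo.constant dense<-2.000000e+00> : tensor<4xf32>
    %n2 = stablehlo.divide %c2, %b : tensor<4xf32>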

wsmoses avatar Mar 15 '24 01:03 wsmoses

  • [x] pad of same values [also taking into account negative vs positive zero]
    %147 = stablehlo.constant dense<0.000000e+00> 
    %81 = stablehlo.constant dense<-0.000000e+00> 
    %1185 = stablehlo.pad %81, %147, low = [...], high = [...], interior = [...]
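
With the elided shapes filled in hypothetically, the whole pad would fold to a single splat constant (treating -0.0 as equal to +0.0, which changes the bit pattern but not the value):

    %cst = stablehlo.constant dense<-0.000000e+00> : tensor<4xf32>
    %zero = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %p = stablehlo.pad %cst, %zero, low = [2], high = [2], interior = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<8xf32>
    // folds to
    %p2 = stablehlo.constant dense<0.000000e+00> : tensor<8xf32>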

wsmoses avatar Mar 15 '24 03:03 wsmoses

  • [x] dot(pad(x with zeros from i up to j along the contracting axis), y) -> dot(x, y[i:j])
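
A sketch with hypothetical names and shapes: zero rows introduced along the contracting dimension contribute nothing to the dot, so they can be dropped on one side and sliced away on the other.

    %xp = stablehlo.pad %x, %zero, low = [2], high = [0], interior = [0] : (tensor<3x8xf32>, tensor<f32>) -> tensor<5x8xf32>
    %d = stablehlo.dot_general %xp, %y, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5x8xf32>, tensor<5x16xf32>) -> tensor<8x16xf32>
    // becomes
    %ys = stablehlo.slice %y [2:5, 0:16] : (tensor<5x16xf32>) -> tensor<3x16xf32>
    %d2 = stablehlo.dot_general %x, %ys, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3x8xf32>, tensor<3x16xf32>) -> tensor<8x16xf32>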

wsmoses avatar Mar 15 '24 03:03 wsmoses

  • [ ] mul(bcast x, mul(bcast y, z)) -> mul(z, bcast(mul x, y))

  • [ ] distributive property: full reduce add of mul(z, bcast y) -> mul(full reduce add of z, y)
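
For the second item, a sketch with a scalar y (hypothetical names and shapes):

    %by = stablehlo.broadcast_in_dim %y, dims = [] : (tensor<f32>) -> tensor<4x8xf32>
    %m = stablehlo.multiply %z, %by : tensor<4x8xf32>
    %r = stablehlo.reduce(%m init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<f32>
    // becomes
    %rz = stablehlo.reduce(%z init: %zero) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<f32>
    %r2 = stablehlo.multiply %rz, %y : tensor<f32>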

wsmoses avatar Mar 15 '24 06:03 wsmoses

  • [x] slice of broadcast -> broadcast
  • [x] slice of transpose -> transpose of slice
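
For slice-of-transpose, a sketch with hypothetical shapes; the slice bounds are permuted onto the pre-transpose dimensions so the transpose moves less data:

    %t = stablehlo.transpose %x, dims = [1, 0] : (tensor<4x8xf32>) -> tensor<8x4xf32>
    %s = stablehlo.slice %t [2:6, 0:4] : (tensor<8x4xf32>) -> tensor<4x4xf32>
    // becomes
    %s2 = stablehlo.slice %x [0:4, 2:6] : (tensor<4x8xf32>) -> tensor<4x4xf32>
    %t2 = stablehlo.transpose %s2, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32>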

wsmoses avatar Mar 16 '24 01:03 wsmoses

  • [ ] slice of convert -> convert of slice (perhaps generalize to any unary op, if the slice is its only user)

wsmoses avatar Mar 16 '24 02:03 wsmoses

  • [x] slice of reshape -> reshape of slice (by @ftynse in https://github.com/EnzymeAD/Enzyme-JAX/pull/60)
  • [x] transpose of dot general A, B -> dot general B, A [where applicable]; should also handle a convert in the middle
  • [x] slice of binop(a, b) [if slice is only user] -> binop(slice(a), slice(b)) (#68 https://github.com/EnzymeAD/Enzyme-JAX/commit/7475dd8fb542a3971111c97378b5f69d59fb4631)

wsmoses avatar Mar 16 '24 18:03 wsmoses

  • [ ] (partial) sum reduce of broadcast

wsmoses avatar Mar 16 '24 18:03 wsmoses

  • [x] convert of convert
  • [x] generalize transpose transpose to transpose convert transpose

wsmoses avatar Mar 17 '24 05:03 wsmoses

  • [x] dot general transpose(A), B or A, transpose(B) -> dot general A, B (where applicable)

wsmoses avatar Mar 17 '24 06:03 wsmoses

(partial) sum reduce of broadcast

Do you have an example of this? There is a bunch of cases that are already supported https://github.com/EnzymeAD/Enzyme-JAX/blob/main/test/lit_tests/broadcastreduce.mlir

ftynse avatar Mar 18 '24 14:03 ftynse

  • [ ] Select of Pad

wsmoses avatar Mar 22 '24 02:03 wsmoses

(partial) sum reduce of broadcast

Do you have an example of this? There is a bunch of cases that are already supported https://github.com/EnzymeAD/Enzyme-JAX/blob/main/test/lit_tests/broadcastreduce.mlir

Unfortunately not presently, but I'll post when it comes up again.

wsmoses avatar Mar 22 '24 02:03 wsmoses

  • [x] Pad of Pad (#67)

wsmoses avatar Mar 22 '24 02:03 wsmoses

  • [x] slice of dot general
    %1205 = stablehlo.dot_general %1181, %482, batching_dims = [0, 2] x [0, 1], contracting_dims = [3, 1] x [2, 3], precision = [DEFAULT, DEFAULT] : (tensor<1x16x1x20x100xbf16>, tensor<1x1x20x16x123xbf16>) -> tensor<1x1x100x123xbf16>
    %1208 = stablehlo.slice %1205 [0:1, 0:1, 75:100, 0:123] : (tensor<1x1x100x123xbf16>) -> tensor<1x1x25x123xbf16>
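
The slice here only restricts result dimension 2, which comes from the lhs free dimension, so it could presumably be pushed onto the lhs (a sketch using the values above):

    %lhs = stablehlo.slice %1181 [0:1, 0:16, 0:1, 0:20, 75:100] : (tensor<1x16x1x20x100xbf16>) -> tensor<1x16x1x20x25xbf16>
    %1208 = stablehlo.dot_general %lhs, %482, batching_dims = [0, 2] x [0, 1], contracting_dims = [3, 1] x [2, 3], precision = [DEFAULT, DEFAULT] : (tensor<1x16x1x20x25xbf16>, tensor<1x1x20x16x123xbf16>) -> tensor<1x1x25x123xbf16>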

wsmoses avatar Mar 29 '24 16:03 wsmoses

  • [ ] reshape(concat(constants or reshapes...)) -> concat(constants or reshapes...)

wsmoses avatar Apr 13 '24 01:04 wsmoses

Generalize pad propagations to work with an interstitial reshape

  • [x] PadPad (need PadReshapePad; see the sketch after this list)
  • [ ] BinopPadtoConcat (need BinopReshapePadtoConcat)
  • [ ] ConcatPad (need ConcatReshapePad)
  • [ ] ReducePad
  • [ ] BroadcastPad
  • [x] MulZeroPad
  • [x] DivZeroPad
  • [x] BinopConstPad
  • [ ] BinopBinopPadPad
  • [ ] AddPadPadtoConcat
  • [ ] UnaryPadPush
  • [ ] TransposePad
  • [x] PadDotGeneral
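
As a sketch of the PadReshapePad case (hypothetical names and shapes, with the reshape only inserting a unit dim), the two pads merge across the reshape:

    %p0 = stablehlo.pad %x, %z, low = [1, 0], high = [0, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor<f32>) -> tensor<4x4xf32>
    %r = stablehlo.reshape %p0 : (tensor<4x4xf32>) -> tensor<4x4x1xf32>
    %p1 = stablehlo.pad %r, %z, low = [0, 2, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<4x4x1xf32>, tensor<f32>) -> tensor<4x6x1xf32>
    // becomes (both pads must use the same pad value %z for the merge)
    %r2 = stablehlo.reshape %x : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
    %p2 = stablehlo.pad %r2, %z, low = [1, 2, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x1xf32>, tensor<f32>) -> tensor<4x6x1xf32>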

wsmoses avatar Apr 14 '24 01:04 wsmoses