Enzyme-JAX

Pull repeated ops from conditional branches

Open · avik-pal opened this issue 8 months ago · 4 comments

```mlir
%3:2 = "stablehlo.if"(%2) ({
  %14 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i64>) -> tensor<4x2xi64>
  %15 = stablehlo.convert %14 : (tensor<4x2xi64>) -> tensor<4x2xf32>
  %16 = stablehlo.divide %1, %15 : tensor<4x2xf32>
  stablehlo.return %16, %16 : tensor<4x2xf32>, tensor<4x2xf32>
}, {
  %14 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i64>) -> tensor<4x2xi64>
  %15 = stablehlo.convert %14 : (tensor<4x2xi64>) -> tensor<4x2xf32>
  %16 = stablehlo.divide %1, %15 : tensor<4x2xf32>
  %17 = stablehlo.add %0, %16 : tensor<4x2xf32>
  stablehlo.return %17, %17 : tensor<4x2xf32>, tensor<4x2xf32>
}) : (tensor<i1>) -> (tensor<4x2xf32>, tensor<4x2xf32>)
```

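Both branches begin with the same three ops, and those ops only read values defined outside the `stablehlo.if` (`%arg0` and `%1`). Since they have no side effects and execute on either path anyway, we should be able to lift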

```mlir
%14 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i64>) -> tensor<4x2xi64>
%15 = stablehlo.convert %14 : (tensor<4x2xi64>) -> tensor<4x2xf32>
%16 = stablehlo.divide %1, %15 : tensor<4x2xf32>
```

into the enclosing block, so that each branch uses the hoisted `%16` instead of recomputing it.
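A minimal sketch of the hoisting step, written as a toy model rather than a real MLIR pass. It assumes: ops are represented by a hypothetical `Op` record (not an Enzyme-JAX or MLIR type), only a shared *leading* run of ops is hoisted, and every op is side-effect free (which holds for `broadcast_in_dim`, `convert` and `divide` here, but a real pass would have to check it). Two ops count as identical when their name, attributes/types and operands agree after renaming results that were already hoisted.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Toy stand-in for one MLIR op (hypothetical, not an Enzyme-JAX type)."""
    result: str      # SSA name this op defines, e.g. "%14" ("" for terminators)
    name: str        # op name, e.g. "stablehlo.convert"
    operands: tuple  # SSA names this op reads
    sig: str = ""    # attributes and types, compared verbatim


def hoist_common_prefix(then_ops, else_ops):
    """Hoist the longest run of identical leading ops out of both branches.

    Two ops match when name, sig and operands agree, after replacing results
    of already-matched ops with their shared hoisted name. Assumes every op
    is side-effect free, so running it unconditionally is safe.
    """
    hoisted = []
    then_map, else_map = {}, {}  # branch-local result name -> hoisted name
    i = 0
    while i < min(len(then_ops), len(else_ops)):
        a, b = then_ops[i], else_ops[i]
        if a.name == "stablehlo.return":  # never hoist the terminator
            break
        a_args = tuple(then_map.get(v, v) for v in a.operands)
        b_args = tuple(else_map.get(v, v) for v in b.operands)
        if (a.name, a.sig, a_args) != (b.name, b.sig, b_args):
            break
        new = f"%h{len(hoisted)}"
        hoisted.append(Op(new, a.name, a_args, a.sig))
        then_map[a.result] = new
        else_map[b.result] = new
        i += 1

    def rename(ops, mapping):
        # Point the remaining ops at the hoisted values instead of local copies.
        return [Op(o.result, o.name, tuple(mapping.get(v, v) for v in o.operands), o.sig)
                for o in ops]

    return hoisted, rename(then_ops[i:], then_map), rename(else_ops[i:], else_map)


# The two branches from the issue, with attributes and types folded into `sig`.
common = [
    Op("%14", "stablehlo.broadcast_in_dim", ("%arg0",), "dims=[] i64->4x2xi64"),
    Op("%15", "stablehlo.convert", ("%14",), "4x2xi64->4x2xf32"),
    Op("%16", "stablehlo.divide", ("%1", "%15"), "4x2xf32"),
]
then_ops = common + [Op("", "stablehlo.return", ("%16", "%16"))]
else_ops = common + [
    Op("%17", "stablehlo.add", ("%0", "%16"), "4x2xf32"),
    Op("", "stablehlo.return", ("%17", "%17")),
]

hoisted, then_rest, else_rest = hoist_common_prefix(then_ops, else_ops)
for op in hoisted:
    print(op.result, "=", op.name, op.operands)
```

On the example above, the three shared ops move out of the `if`; the first branch reduces to a `stablehlo.return` of the hoisted divide result, and the second keeps only the `add` (now reading the hoisted value) and its return. Matching only a common prefix keeps the sketch simple; a more general pass could match identical ops anywhere in both regions as long as their operands are available outside the `if`.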

avik-pal, Mar 08 '25 17:03