Enzyme-JAX
Enzyme-JAX copied to clipboard
Hoist ops that are repeated in both branches out of a conditional
%3:2 = "stablehlo.if"(%2) ({
%14 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i64>) -> tensor<4x2xi64>
%15 = stablehlo.convert %14 : (tensor<4x2xi64>) -> tensor<4x2xf32>
%16 = stablehlo.divide %1, %15 : tensor<4x2xf32>
stablehlo.return %16, %16 : tensor<4x2xf32>, tensor<4x2xf32>
}, {
%14 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i64>) -> tensor<4x2xi64>
%15 = stablehlo.convert %14 : (tensor<4x2xi64>) -> tensor<4x2xf32>
%16 = stablehlo.divide %1, %15 : tensor<4x2xf32>
%17 = stablehlo.add %0, %16 : tensor<4x2xf32>
stablehlo.return %17, %17 : tensor<4x2xf32>, tensor<4x2xf32>
}) : (tensor<i1>) -> (tensor<4x2xf32>, tensor<4x2xf32>)
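Both branches start with the same three ops, and those ops only use values defined outside the `stablehlo.if` (`%arg0` and `%1`), so we should be able to lift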
%14 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i64>) -> tensor<4x2xi64>
%15 = stablehlo.convert %14 : (tensor<4x2xi64>) -> tensor<4x2xf32>
%16 = stablehlo.divide %1, %15 : tensor<4x2xf32>
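into the outer block, ahead of the `stablehlo.if`, so that they are computed once and both branches reuse the hoisted `%16`.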