Enzyme-JAX

Generalize Batch-Op

Status: Open. avik-pal opened this issue 6 months ago · 3 comments

// Reproducer for the "Generalize Batch-Op" issue: @main applies @unbatched_fn
// over a batch of size 7 through `enzyme.batch`.
// NOTE(review): the `enzyme.batch` call in @main passes a single operand, while
// @unbatched_fn declares three parameters; see the note at that call site.
module @reactant_batched... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  // Scalar kernel for a Float32 type cast (per its name). Input and output are
  // both f32, so it returns its argument unchanged.
  func.func private @"Reactant.TracedUtils.TypeCast{Float32}()_broadcast_scalar"(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  // Second identity cast kernel; @unbatched_fn batches it over the 6x1 operand.
  func.func private @"Reactant.TracedUtils.TypeCast{Float32}()_broadcast_scalar_1"(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  // Scalar kernel for `+`: returns the sum followed by both inputs unchanged.
  // @unbatched_fn only consumes result #0.
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Per-example computation: dot(%arg0 [5x6], %arg2 [6] as a 6x1 column)
  // yields 5 values, and %arg1 [5] is then added elementwise. Returns tensor<5xf32>.
  func.func private @unbatched_fn(%arg0: tensor<5x6xf32>, %arg1: tensor<5xf32>, %arg2: tensor<6xf32>) -> tensor<5xf32> {
    // Zero constant and its reshapes (%0-%2); none of these values is used below.
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<5xf32>
    %0 = stablehlo.transpose %cst, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    %1 = stablehlo.reshape %0 : (tensor<5xf32>) -> tensor<1x5xf32>
    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<1x5xf32>) -> tensor<5x1xf32>
    %3 = stablehlo.convert %arg0 : tensor<5x6xf32>
    // Turn %arg2 into a 6x1 column.
    %4 = stablehlo.transpose %arg2, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %5 = stablehlo.reshape %4 : (tensor<6xf32>) -> tensor<1x6xf32>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<1x6xf32>) -> tensor<6x1xf32>
    %7 = stablehlo.convert %6 : tensor<6x1xf32>
    // Same-shape broadcasts and f32->f32 converts around identity-cast batches.
    %8 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<5x6xf32>) -> tensor<5x6xf32>
    %9 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float32}()_broadcast_scalar"(%8) {batch_shape = array<i64: 5, 6>} : (tensor<5x6xf32>) -> tensor<5x6xf32>
    %10 = stablehlo.convert %9 : tensor<5x6xf32>
    %11 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<6x1xf32>) -> tensor<6x1xf32>
    %12 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float32}()_broadcast_scalar_1"(%11) {batch_shape = array<i64: 6, 1>} : (tensor<6x1xf32>) -> tensor<6x1xf32>
    %13 = stablehlo.convert %12 : tensor<6x1xf32>
    // [5x6] x [6x1] -> [5x1]; as written, the contracting dimension is 6 on both sides.
    %14 = stablehlo.dot_general %10, %13, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5x6xf32>, tensor<6x1xf32>) -> tensor<5x1xf32>
    // Flatten the 5x1 result back to a 5-element vector.
    %15 = stablehlo.transpose %14, dims = [1, 0] : (tensor<5x1xf32>) -> tensor<1x5xf32>
    %16 = stablehlo.reshape %15 : (tensor<1x5xf32>) -> tensor<5xf32>
    %17 = stablehlo.transpose %16, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    %18 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    %19 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    // Elementwise add with %arg1; results #1 and #2 are never used.
    %20:3 = enzyme.batch @"+_broadcast_scalar"(%18, %19) {batch_shape = array<i64: 5>} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>)
    %21 = stablehlo.convert %20#0 : tensor<5xf32>
    return %21 : tensor<5xf32>
  }
  // Entry point. Result 0 is the batched output (7x5). Results 1-3 return
  // %arg0, %arg1 and %arg2 after round-trip transposes, which matches their
  // tf.aliasing_output attributes.
  func.func @main(%arg0: tensor<6x5xf32> {tf.aliasing_output = 1 : i32}, %arg1: tensor<7x6xf32> {tf.aliasing_output = 2 : i32}, %arg2: tensor<5xf32> {tf.aliasing_output = 3 : i32}) -> (tensor<7x5xf32>, tensor<6x5xf32>, tensor<7x6xf32>, tensor<5xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<6x5xf32>) -> tensor<5x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<7x6xf32>) -> tensor<6x7xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    // Two transposes cancel, so %3 has the same layout as %arg1 (7x6).
    %3 = stablehlo.transpose %1, dims = [1, 0] : (tensor<6x7xf32>) -> tensor<7x6xf32>
    // Unused constant.
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<6xf32>
    // NOTE(review): only %3 (7x6) is passed, but @unbatched_fn takes
    // (tensor<5x6xf32>, tensor<5xf32>, tensor<6xf32>). The 7 rows of %3 match
    // the last parameter. %0 (5x6) and %2 (5), whose shapes match the first two
    // parameters and which are not batched, are not forwarded. Supporting such
    // non-batched operands is presumably what this issue asks for. The reported
    // dot_general contracting-dimension error may stem from this operand/parameter
    // mismatch; confirm against the batching pass.
    %4 = enzyme.batch @unbatched_fn(%3) {batch_shape = array<i64: 7>} : (tensor<7x6xf32>) -> tensor<7x5xf32>
    %5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<7x5xf32>) -> tensor<5x7xf32>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<5x7xf32>) -> tensor<7x5xf32>
    %7 = stablehlo.transpose %0, dims = [1, 0] : (tensor<5x6xf32>) -> tensor<6x5xf32>
    %8 = stablehlo.transpose %1, dims = [1, 0] : (tensor<6x7xf32>) -> tensor<7x6xf32>
    %9 = stablehlo.transpose %2, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    return %6, %7, %8, %9 : tensor<7x5xf32>, tensor<6x5xf32>, tensor<7x6xf32>, tensor<5xf32>
  }
}
Reported error:

loc("dot_general"("/mnt/software/lux/Reactant.jl/src/Ops.jl":779:0)): error: contracting dimension sizes must match for lhs/rhs

— avik-pal, Apr 24 '25 21:04