
Preserve HLO shardings on calls and non-entry functions.

Open copybara-service[bot] opened this issue 1 year ago • 0 comments


XLA doesn't inline a `call` instruction and its called computation if the `call` has a `backend_config`, so Shardy does the same. GSPMD propagation also adds shardings to the `call` instruction and to the called function's inputs/outputs.

However, when given this module:

```mlir
module @jit_f attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = true, mhlo.num_partitions = 4 : i32, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"} loc("x")) -> (tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) {
    %0 = call @called_computation(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, mhlo.sharding = "{devices=[2,2]<=[4]}"} : (tensor<8x2xi32>) -> tensor<8x2xi32> loc(#loc33)
    %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = "", mhlo.sharding = "{devices=[2,2]<=[4]}"} : (tensor<8x2xi32>) -> tensor<8x2xi32> loc(#loc12)
    return %1 : tensor<8x2xi32> loc(#loc)
  } loc(#loc)
  func.func private @called_computation(%arg0: tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"} loc("param_0")) -> (tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) {
    %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, mhlo.sharding = "{devices=[2,2]<=[4]}"} : tensor<8x2xi32> loc(#loc33)
    return %0 : tensor<8x2xi32>
  }
}
```
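The `{devices=[2,2]<=[4]}` annotations above use HLO's iota-based sharding notation: the tile assignment is `iota(4)` reshaped, row-major, onto a 2x2 tile grid. A minimal sketch of that device assignment using NumPy (the helper name is illustrative, not part of XLA):

```python
import numpy as np

def iota_tile_assignment(dims, num_devices):
    # "{devices=[2,2]<=[4]}" means: take device ids 0..num_devices-1
    # (an iota) and reshape them row-major into `dims` to get the
    # per-tile device assignment.
    return np.arange(num_devices).reshape(dims)

grid = iota_tile_assignment([2, 2], 4)
print(grid)
# [[0 1]
#  [2 3]]
```

So each 4x1 tile of the 8x2 tensor is placed on devices 0, 1, 2, and 3 in that row-major order.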

HLO conversion currently drops the shardings on the `call` instruction and on `@called_computation`. This PR preserves them.

copybara-service[bot] — Sep 30 '24 17:09