xla
Preserve HLO shardings on calls and non-entry functions.
XLA doesn't inline a call instruction (or its called function) if the call has a backend_config, and Shardy follows the same rule. GSPMD propagation also assigns shardings to the call instruction and to the inputs and outputs of the called function.
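The rule above can be modeled in a few lines. This is a minimal Python sketch, not XLA's or Shardy's implementation. `CallOp`, `is_inlined`, and `shardings_to_keep` are hypothetical names, and treating an empty `backend_config` as absent is an assumption.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical, simplified model of a call instruction; CallOp and the helper
# functions below are illustrative names, not XLA or Shardy APIs.
@dataclass
class CallOp:
    callee: str
    backend_config: Optional[str] = None
    sharding: Optional[str] = None

def is_inlined(call: CallOp) -> bool:
    # Assumption: a missing or empty backend_config means "no backend_config".
    return not call.backend_config

def shardings_to_keep(call, callee_arg_shardings, callee_result_shardings):
    """Shardings that must survive conversion when the call is not inlined."""
    if is_inlined(call):
        # The callee body is inlined into the caller, so there is no separate
        # call instruction or function signature left to annotate.
        return None
    return {
        "call": call.sharding,
        "callee_args": list(callee_arg_shardings),
        "callee_results": list(callee_result_shardings),
    }

host_call = CallOp(
    callee="called_computation",
    backend_config='{"device_type":"DEVICE_TYPE_HOST"}',
    sharding="{devices=[2,2]<=[4]}",
)
print(shardings_to_keep(host_call, ["{devices=[2,2]<=[4]}"], ["{devices=[2,2]<=[4]}"]))
```

Keeping the decision (`is_inlined`) separate from what gets preserved mirrors the point of this change: once a call is known to stay un-inlined, its own sharding and the callee's signature shardings are still meaningful and should not be dropped.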
However, when given this module:
```mlir
module @jit_f attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = true, mhlo.num_partitions = 4 : i32, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"} loc("x")) -> (tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) {
    %0 = call @called_computation(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, mhlo.sharding = "{devices=[2,2]<=[4]}"} : (tensor<8x2xi32>) -> tensor<8x2xi32> loc(#loc33)
    %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = "", mhlo.sharding = "{devices=[2,2]<=[4]}"} : (tensor<8x2xi32>) -> tensor<8x2xi32> loc(#loc12)
    return %1 : tensor<8x2xi32> loc(#loc)
  } loc(#loc)
  func.func private @called_computation(%arg0: tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"} loc("param_0")) -> (tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) {
    %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, mhlo.sharding = "{devices=[2,2]<=[4]}"} : tensor<8x2xi32> loc(#loc33)
    return %0 : tensor<8x2xi32>
  }
}
```
converting it to HLO drops the shardings on the call instruction and on the parameters and results of @called_computation. This PR preserves them.
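For reference, every sharding in the module above is `{devices=[2,2]<=[4]}`, an iota tile assignment: device ids 0..3 are laid out row-major on a 2x2 grid, so each device holds a 4x1 block of the `tensor<8x2xi32>`. The Python sketch below computes those blocks. `device_slices` is a hypothetical helper, and it handles only this simple form (no transposed iota, no replicated or manual subgroups, evenly divisible dimensions).

```python
import itertools
import math

def device_slices(shape, tile_shape, num_devices):
    """Map each device id to the [start, end) ranges it owns per dimension.

    Models "{devices=[t0,t1,...]<=[N]}": device ids 0..N-1 are laid out in
    row-major order over a grid of shape tile_shape, and each tensor
    dimension is split evenly across the matching grid dimension.
    """
    if math.prod(tile_shape) != num_devices:
        raise ValueError("tile shape must cover every device")
    if any(d % t for d, t in zip(shape, tile_shape)):
        raise ValueError("this sketch assumes evenly divisible dimensions")
    block = [d // t for d, t in zip(shape, tile_shape)]
    slices = {}
    # itertools.product yields grid coordinates in row-major order, which
    # matches the iota order of device ids.
    for device, coord in enumerate(itertools.product(*(range(t) for t in tile_shape))):
        slices[device] = tuple((c * b, (c + 1) * b) for c, b in zip(coord, block))
    return slices

# tensor<8x2xi32> with {devices=[2,2]<=[4]}
print(device_slices((8, 2), (2, 2), 4))
# {0: ((0, 4), (0, 1)), 1: ((0, 4), (1, 2)), 2: ((4, 8), (0, 1)), 3: ((4, 8), (1, 2))}
```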