Add jax_test configs for shardy and enable it for pjit_test.py and fix any tests.

copybara-service[bot] opened this issue.

Tests fixed include:

  • test_globally_sharded_key_array_8x4_multi_device
    • The issue was in replicate_trailing_dims, which always created an xc.OpSharding. Fixed by creating an equivalent SDY sharding instead.
  • test_aot_out_info
    • The issue was that there was no mesh, since there weren't any NamedShardings. Fixed by no longer asserting that a mesh tuple exists in lower_jaxpr_to_module when adding the sdy MeshOp (there won't be any propagation in that case).
  • test_concurrent_pjit
    • In Shardy, a tensor dimension of size 0 that was sharded on an axis caused a verification error. JAX, however, accepts this when the axis has size 1, so Shardy now makes the same assumption.
  • test_globally_sharded_key_array_result_8x4_single_device
    • This test adds a WSC (with_sharding_constraint) when no mesh_shape_tuple exists ("sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>), so a mesh named mesh with a single device ID is now created if one doesn't already exist.
  • testLowerCostAnalysis
    • This calls into mlir_module_to_xla_computation, which uses its own MLIR parsing function in //third_party/tensorflow/compiler/xla/python/mlir.cc. The SDY dialect needed to be registered there.
  • testShardingConstraintWithArray
    • This calls .compiler_ir(dialect="hlo"), which calls PyMlirModuleToXlaComputation to convert the MLIR to HLO, but the module still contains the SDY dialect. Fixed by exporting the SDY dialect before converting to HLO.
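The relaxed verification rule from the test_concurrent_pjit fix can be modeled in a few lines of Python. This is a hypothetical sketch, not Shardy's actual verifier: the name is_valid_dim_sharding and its arguments are illustrative, and it assumes that a dimension sharded on several axes is effectively split across the product of their sizes.

```python
import math
from typing import Mapping, Sequence


def is_valid_dim_sharding(
    dim_size: int,
    axis_names: Sequence[str],
    mesh_axes: Mapping[str, int],
) -> bool:
    """Hypothetical model of the relaxed size-0 check.

    A dimension of size 0 may be sharded only if the combined size of the
    axes it is sharded on is 1 (so the sharding is effectively a no-op),
    matching what JAX accepts. Non-empty dimensions are not restricted by
    this particular check.
    """
    if dim_size != 0 or not axis_names:
        return True
    sharded_size = math.prod(mesh_axes[name] for name in axis_names)
    return sharded_size == 1


mesh = {"x": 1, "y": 4}
print(is_valid_dim_sharding(0, ["x"], mesh))       # size-1 axis: accepted
print(is_valid_dim_sharding(0, ["y"], mesh))       # size-4 axis: rejected
print(is_valid_dim_sharding(0, ["x", "y"], mesh))  # combined size 4: rejected
print(is_valid_dim_sharding(8, ["y"], mesh))       # non-empty dim: unaffected
```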

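The two mesh-related fixes (test_aot_out_info and test_globally_sharded_key_array_result_8x4_single_device) boil down to one decision: don't require a mesh when nothing is sharded, but if a sharding constraint refers to @mesh and no mesh was collected, fall back to a single-device mesh named mesh. The Python below is a hypothetical model of that decision, not the actual code in lower_jaxpr_to_module; MeshSpec, meshes_to_emit, and the device_id parameter are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class MeshSpec:
    """Hypothetical stand-in for an sdy MeshOp: named axes plus explicit device IDs."""
    name: str
    axes: tuple[tuple[str, int], ...] = ()
    device_ids: tuple[int, ...] = ()


def meshes_to_emit(
    mesh_shape_tuple: Optional[tuple[tuple[str, int], ...]],
    has_sharding_constraint: bool,
    device_id: int,
) -> list[MeshSpec]:
    if mesh_shape_tuple is not None:
        return [MeshSpec("mesh", axes=mesh_shape_tuple)]
    if has_sharding_constraint:
        # A constraint such as #sdy.sharding<@mesh, ...> still needs a mesh
        # named "mesh" to refer to, so fall back to a single-device mesh.
        return [MeshSpec("mesh", device_ids=(device_id,))]
    # No NamedShardings and no constraints: nothing to propagate, so instead
    # of asserting that a mesh exists, emit no mesh at all.
    return []


print(meshes_to_emit((("x", 8), ("y", 4)), True, 0))
print(meshes_to_emit(None, False, 0))
print(meshes_to_emit(None, True, 0))
```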
copybara-service[bot], Aug 08 '24 12:08