Add jax_test configs for shardy and enable it for pjit_test.py and fix any tests.
Tests fixed include:
- `test_globally_sharded_key_array_8x4_multi_device` - Issue was in `replicate_trailing_dims`, where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding (first sketch after this list).
- `test_aot_out_info` - The issue was that there was no mesh, since there weren't any NamedShardings. Fixed by not asserting that a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation); see the second sketch below.
- `test_concurrent_pjit` - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension was sharded on an axis. But if the axis is of size 1, then JAX says this is okay, so have Shardy assume the same (third sketch below).
- `test_globally_sharded_key_array_result_8x4_single_device` - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device id in case one doesn't exist (fourth sketch below).
- `testLowerCostAnalysis` - This calls into `mlir_module_to_xla_computation`, which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it (repro below).
- `testShardingConstraintWithArray` - This calls `.compiler_ir(dialect="hlo")`, which calls `PyMlirModuleToXlaComputation`, which converts the MLIR to HLO, but the Sdy dialect is still inside. Export it before converting it to HLO (second repro below).
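
The sketches and repros below illustrate the fixes; they are not the literal patches, and any helper names not mentioned above are hypothetical.

First, a minimal sketch of the `replicate_trailing_dims` fix: build a textual `#sdy.sharding` attribute instead of an `xc.OpSharding`. The `{}` = fully-replicated-dimension syntax matches the SDY attribute shown in the list above; `leading_dim_axes` is an assumed input.

```python
def sdy_sharding_with_replicated_trailing_dims(
    mesh_name: str,
    leading_dim_axes: list[list[str]],
    num_trailing_dims: int,
) -> str:
  """Builds a textual #sdy.sharding attr; trailing dims get `{}` (replicated)."""
  dims = ['{%s}' % ', '.join(f'"{a}"' for a in axes)
          for axes in leading_dim_axes]
  dims += ['{}'] * num_trailing_dims  # `{}` marks a dimension as replicated.
  return f'#sdy.sharding<@{mesh_name}, [{", ".join(dims)}]>'

# sdy_sharding_with_replicated_trailing_dims('mesh', [['x'], ['y']], 1)
# -> '#sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>'
```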
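Second, a sketch of the `lower_jaxpr_to_module` change, assuming a `mesh_shape_tuple`-like value that is `None` when no NamedShardings are present (the real code builds MLIR ops, not strings):

```python
def maybe_emit_sdy_mesh(mesh_shape_tuple) -> str | None:
  # Previously this asserted mesh_shape_tuple was set; with no NamedShardings
  # there is no mesh and no Shardy propagation, so skip the MeshOp instead.
  if mesh_shape_tuple is None:
    return None
  axes = ', '.join(f'"{name}"={size}' for name, size in mesh_shape_tuple)
  return f'sdy.mesh @mesh = <[{axes}]>'  # e.g. sdy.mesh @mesh = <["x"=8, "y"=4]>
```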
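Third, the size-0-dimension rule from `test_concurrent_pjit` written out as a tiny predicate (the actual check lives in Shardy's verifier, in C++):

```python
def size_zero_dim_sharding_ok(axis_size: int) -> bool:
  # Shardy used to reject any sharded dimension of size 0. JAX treats
  # sharding on an axis of size 1 as a no-op and allows it, so Shardy
  # now accepts that case as well.
  return axis_size == 1
```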
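Fourth, the missing-mesh fallback: unlike the skip in the second sketch, here a `"sdy.sharding_constraint"` still refers to `@mesh`, so a mesh named `mesh` with a single device id is emitted so the symbol exists. The `device_ids` spelling is my assumption about the SDY mesh syntax:

```python
# Assumed SDY syntax for a trivial mesh pinned to one explicit device id.
SINGLE_DEVICE_MESH = 'sdy.mesh @mesh = <[], device_ids=[0]>'

def mesh_op_or_default(mesh_shape_tuple) -> str:
  # A sharding constraint can reference @mesh even when no mesh_shape_tuple
  # was recorded; fall back to the single-device mesh instead of failing.
  if mesh_shape_tuple is None:
    return SINGLE_DEVICE_MESH
  axes = ', '.join(f'"{name}"={size}' for name, size in mesh_shape_tuple)
  return f'sdy.mesh @mesh = <[{axes}]>'
```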
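A repro of the `testLowerCostAnalysis` path using public JAX APIs: per the list above, `cost_analysis()` ends up in `mlir_module_to_xla_computation`, whose standalone MLIR parser needed the SDY dialect registered once Shardy ops appear in the module.

```python
import jax
import jax.numpy as jnp

lowered = jax.jit(lambda a: a * 2).lower(jnp.arange(4.0))
# Parses the lowered MLIR module on the C++ side; with Shardy enabled this
# failed until the SDY dialect was registered in mlir.cc.
costs = lowered.cost_analysis()
```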
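And a repro of the `testShardingConstraintWithArray` path. The `jax_use_shardy_partitioner` flag is how current JAX enables Shardy, and `jax.make_mesh` is a recent public API; treat both as assumptions about the environment:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.update('jax_use_shardy_partitioner', True)
mesh = jax.make_mesh((1,), ('x',))  # single-device mesh so this runs anywhere

@jax.jit
def f(a):
  return jax.lax.with_sharding_constraint(a, NamedSharding(mesh, P('x')))

# compiler_ir(dialect="hlo") goes through PyMlirModuleToXlaComputation; the
# module's SDY ops must be exported before the MLIR-to-HLO conversion.
hlo = f.lower(jnp.arange(8.0)).compiler_ir(dialect="hlo")
```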