
[fix] Apply `with_sharding_constraint` recursively to pytrees

Open dvruette opened this issue 5 months ago • 0 comments

[generated by copilot]

This pull request refactors the handling of sharding constraints in `eformer/escale/partition/constraints.py`. It renames and updates the existing function that applies a sharding constraint to a single array, and it adds a new function that handles PyTrees of JAX arrays with additional validation and correction logic.

Function renaming and improvement:

  • Renamed `with_sharding_constraint` to `array_with_sharding_constraint` to make clear that it operates on a single JAX array. Updated the function's type annotations to use `jax.Array` for consistency.

New functionality for PyTrees:

  • Introduced a new `with_sharding_constraint` function that applies sharding constraints to PyTrees of JAX arrays. It checks that the input PyTree structure is compatible with the sharding specification, verifies that every element of the sharding specification has a valid type, and corrects sharding axes that are incompatible with the array they are applied to.
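The behaviour described above can be illustrated with a minimal sketch. This is not the code from this pull request: the names `correct_spec`, `constrain_array`, and `constrain_tree` are hypothetical, and the correction rule (drop a mesh axis when its size does not divide the array dimension, pad or truncate the spec to the array rank) is an assumption about what "correcting incompatible axes" means. Only `jax.lax.with_sharding_constraint`, `jax.tree_util`, and `jax.sharding` are real JAX APIs.

```python
import math

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def correct_spec(shape, spec, axis_sizes):
    """Replace spec entries that cannot evenly shard their dimension with None.

    Assumption: a dimension is "incompatible" when the product of the mesh
    axis sizes assigned to it does not divide the dimension length.
    """
    entries = list(tuple(spec))[: len(shape)]        # assumed: truncate extra entries
    entries += [None] * (len(shape) - len(entries))  # pad missing entries with None
    fixed = []
    for dim, entry in zip(shape, entries):
        if entry is None:
            names = ()
        elif isinstance(entry, str):
            names = (entry,)
        else:
            names = tuple(entry)
        size = math.prod(axis_sizes[name] for name in names)
        fixed.append(entry if dim % size == 0 else None)
    return PartitionSpec(*fixed)


def constrain_array(x, spec, mesh):
    """Hypothetical single-array helper: correct the spec, then constrain."""
    spec = correct_spec(x.shape, spec, dict(mesh.shape))
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


def constrain_tree(tree, specs, mesh):
    """Hypothetical PyTree helper: validate, then constrain every leaf."""
    is_spec = lambda s: isinstance(s, PartitionSpec)
    if is_spec(specs):
        # A single spec is applied to every array in the tree.
        return jax.tree_util.tree_map(lambda x: constrain_array(x, specs, mesh), tree)

    # Treat each PartitionSpec as one leaf rather than descending into it.
    spec_leaves, spec_def = jax.tree_util.tree_flatten(specs, is_leaf=is_spec)
    for leaf in spec_leaves:
        if not is_spec(leaf):
            raise TypeError(f"expected PartitionSpec, got {type(leaf).__name__}")

    tree_def = jax.tree_util.tree_structure(tree)
    if tree_def != spec_def:
        raise ValueError(f"structure mismatch: {tree_def} vs {spec_def}")

    return jax.tree_util.tree_map(lambda x, s: constrain_array(x, s, mesh), tree, specs)
```

With a mesh such as `Mesh(np.array(jax.devices()), ("dp",))`, calling `constrain_tree(params, specs, mesh)` would return a PyTree with the same structure as `params`. Keeping `correct_spec` free of device state makes the correction rule easy to check against hypothetical mesh sizes.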

dvruette, Aug 02 '25 09:08