FJFormer
[fix] Apply `with_sharding_constraint` recursively to pytrees
[generated by copilot]
This pull request refactors and improves the handling of sharding constraints in `eformer/escale/partition/constraints.py`. The main changes are a rename and cleanup of the existing function that applies a sharding constraint to a single array, and a new function that applies sharding constraints to PyTrees of JAX arrays, with additional validation and correction logic.
Function renaming and improvement:
- Renamed `with_sharding_constraint` to `array_with_sharding_constraint` to make clear that it operates on a single JAX array. Updated the function's type annotations to use `jax.Array` for clarity and consistency.
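A minimal sketch of what a single-array helper like `array_with_sharding_constraint` might look like. It assumes the helper simply wraps `jax.lax.with_sharding_constraint` and takes an explicit `Mesh` and `PartitionSpec`. That signature is hypothetical, not the one in the PR, and the `"dp"` axis name and `scale` function are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def array_with_sharding_constraint(
    x: jax.Array, mesh: Mesh, spec: PartitionSpec
) -> jax.Array:
    """Constrain the layout of a single jax.Array (hypothetical signature)."""
    # NamedSharding binds the logical PartitionSpec to the physical device mesh.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


# One-dimensional mesh over every available device, with one axis named "dp".
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))


@jax.jit
def scale(x: jax.Array) -> jax.Array:
    y = x * 2.0
    # Ask the compiler to split the rows of y along the "dp" mesh axis.
    return array_with_sharding_constraint(y, mesh, PartitionSpec("dp", None))


# The row count is a multiple of the device count, so "dp" divides it evenly.
out = scale(jnp.ones((n * 4, 4)))
```

The constraint is only a hint to the compiler. It does not change the values, only how the intermediate `y` is laid out across devices.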
New functionality for PyTrees:
- Introduced a new `with_sharding_constraint` function that applies sharding constraints to PyTrees of JAX arrays. It checks that the structure of the input PyTree matches the sharding specification, verifies that every element of the specification has a valid type, and corrects incompatible sharding axes.