Adding `tree_util.stack_leaves()` and `tree_util.unstack_leaves()`

Open • ayaka14732 opened this issue 1 year ago • 5 comments

  • stack_leaves: Stack the leaves of one or more PyTrees along a new axis.
  • unstack_leaves: Unstack the leaves of a PyTree along an axis, returning a list of PyTrees (the inverse of stack_leaves).
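A minimal sketch of the two proposed helpers. It assumes `stack_leaves` takes a sequence of PyTrees that all share one structure, and that `unstack_leaves` returns a Python list with one PyTree per slice along `axis`. Neither function exists in `jax.tree_util`; the names, signatures, and the `axis=0` defaults are hypothetical.

```python
import jax
import jax.numpy as jnp


def stack_leaves(pytrees, axis=0):
    # Hypothetical helper: stack matching leaves of same-structured pytrees
    # along a new axis. The `*` passes each pytree as its own argument.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)


def unstack_leaves(pytree, axis=0):
    # Hypothetical helper: slice every leaf along `axis` and rebuild one
    # pytree per slice, so unstack_leaves(stack_leaves(ts)) recovers ts.
    leaves, treedef = jax.tree.flatten(pytree)
    n = leaves[0].shape[axis]
    return [
        jax.tree.unflatten(treedef, [jnp.take(leaf, i, axis=axis) for leaf in leaves])
        for i in range(n)
    ]


# Example data: four pytrees with the same structure.
trees = [{"w": jnp.ones(3) * i, "b": jnp.array(float(i))} for i in range(4)]
stacked = stack_leaves(trees)      # leaves gain a leading axis of size 4
restored = unstack_leaves(stacked)  # back to a list of four pytrees
```

Round-tripping through both helpers gives back pytrees with the original structure and values.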

References:

  • https://docs.liesel-project.org/en/v0.1.4/_modules/liesel/goose/pytree.html#stack_leaves
  • https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75?permalink_comment_id=4634557#gistcomment-4634557
  • https://github.com/ayaka14732/llama-2-jax/blob/ab33e1f15489daa8b9040389c77e486cd450e461/lib/tree_utils/init.py

ayaka14732 • Apr 25 '24 13:04

To be clear, are these the semantics you have in mind?

import jax
import jax.numpy as jnp
def stack_leaves(pytrees, axis):
  return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)  # stack matching leaves of every pytree
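To illustrate the semantics under discussion, here is a small self-contained sketch (the dicts `a` and `b` and their keys are arbitrary example data). It shows how `axis` controls where the new dimension goes, and why each pytree has to reach `jax.tree.map` as a separate argument.

```python
import jax
import jax.numpy as jnp

# Two pytrees with the same structure; every leaf has shape (3,).
a = {"x": jnp.zeros(3), "y": jnp.zeros(3)}
b = {"x": jnp.ones(3), "y": jnp.ones(3)}

# With the pytrees unpacked, the lambda receives the matching leaf from
# each tree at once, so jnp.stack sees two arrays per call.
along_0 = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[a, b])  # leaves: (2, 3)
along_1 = jax.tree.map(lambda *xs: jnp.stack(xs, axis=1), *[a, b])  # leaves: (3, 2)

# Without the `*`, the list itself is treated as one pytree: the lambda
# gets a single leaf per call, and the result stays a list of two dicts
# whose leaves only gained a size-1 leading axis.
no_star = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), [a, b])
```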

jakevdp • Apr 25 '24 14:04

To be clear, are these the semantics you have in mind?

Yes

ayaka14732 • Apr 25 '24 15:04

For something like this, I'd probably lean toward recommending that users implement what they need by composing existing APIs, rather than adding a new API for something that can already be expressed succinctly. What do you think?
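As an illustration of the composability argument, a sketch assuming per-example parameter dicts with hypothetical keys `w` and `b`: the pytree built with the stacking one-liner feeds straight into `jax.vmap`, and unstacking is a similarly short list comprehension over `jax.tree.map`.

```python
import jax
import jax.numpy as jnp

# Three parameter sets with identical structure (illustrative values).
params = [{"w": jnp.full((2,), float(i)), "b": jnp.array(float(i))} for i in range(3)]

# Stack with existing APIs: every leaf gains a leading axis of size 3.
stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *params)


def apply(p, x):
    # A toy affine function of a single parameter set.
    return jnp.dot(p["w"], x) + p["b"]


x = jnp.ones(2)
# in_axes=0 maps over the leading axis of every leaf in `stacked`;
# in_axes=None shares `x` across all parameter sets.
ys = jax.vmap(apply, in_axes=(0, None))(stacked, x)

# Unstack with existing APIs: index each leaf, once per parameter set.
unstacked = [jax.tree.map(lambda leaf: leaf[i], stacked) for i in range(len(params))]
```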

jakevdp • Apr 25 '24 16:04

Maybe adding a tree_util cookbook would be useful? @jakevdp

ASEM000 • Apr 26 '24 13:04

A pytree cookbook would be an interesting idea! This idea also came up in #20594. @ayaka14732, is that something you'd be interested in thinking about?

jakevdp • Apr 26 '24 15:04