Alternative Sharing

Open bowlingmh opened this issue 7 months ago • 3 comments

I was looking at how sharing was added to equinox, and was a bit surprised at the approach taken. I had a similar challenge: I had a PyTree with shared jax.Arrays, and jit'ing a function that took such a PyTree as input ended up being really slow (presumably because it copied the same array to the device over and over).

I took a different approach, and I'm wondering if this might be a better or alternative choice for equinox.

Obviously a PyTree can represent a DAG, as you can simply have two sub-PyTrees be, in fact, the same object. The problem is that running it through a jax transformation (namely jit) treats the repeated objects as independent. So why not use the same partition/combine idea from equinox to do an unshare/share separation around the problematic jax transformations? Here unshare removes the duplication, leaving a placeholder that records where each duplicate came from, and share restores the duplication on the inside of the transformation.
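To make the starting point concrete, here is a minimal sketch showing that flattening a PyTree yields one leaf per position, even when two positions hold the same array object. The names `w` and `tree` are purely illustrative and not part of equinox.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

w = jnp.ones((3,))
# The same array object appears at two positions in the PyTree.
tree = {"encoder": w, "decoder": w}

leaves, treedef = jtu.tree_flatten(tree)

# Flattening reports two leaves, one per position ...
num_leaves = len(leaves)
# ... yet both entries are the very same Python object.
same_object = leaves[0] is leaves[1]
```

Object identity is therefore still visible right before the transformation, which is the information an unshare step can use.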

Here's what I implemented to solve the problem. The key is a meta-decorator, share_through, that does this unshare/share around some other transformation decorator. In my case, I use it as @share_through(eqx.filter_jit) on top of a function I want to jit that takes PyTrees with shared structure.

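import functools

import jax.tree_util as jtu


# Placeholder leaf: the flat index of the leaf this position shares with.
class _ShareIndex(int):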
  def __repr__(self):
    return f'*{int(self)}*'
  

def _tree_unshare(x):
  leaves, treedef = jtu.tree_flatten(x)
  
  # Replace each repeated leaf (same object identity) with a _ShareIndex
  # pointing at the flat position of its first occurrence.
  shared_leaves = []
  shared_ids = {}
  for l in leaves:
    if id(l) in shared_ids:
      shared_leaves.append(_ShareIndex(shared_ids[id(l)]))
    else:
      shared_leaves.append(l)
      shared_ids[id(l)] = len(shared_leaves) - 1

  return jtu.tree_unflatten(treedef, shared_leaves)


def _tree_share(x):
  leaves, treedef = jtu.tree_flatten(x)

  # Swap each placeholder back for the leaf it points to.
  leaves = [leaves[l] if isinstance(l, _ShareIndex) else l for l in leaves]

  return jtu.tree_unflatten(treedef, leaves)


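# Wrap decorator `dec` so inputs and outputs cross it in unshared form,
# while `f` itself still sees (and returns) PyTrees with shared leaves.
def share_through(dec):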
  def _dec(f):
    def _f_inner(*args, **kwargs):
      args = tuple(_tree_share(a) for a in args)
      kwargs = {k: _tree_share(v) for k,v in kwargs.items()}
      return _tree_unshare(f(*args, **kwargs))

    _dec_f = dec(_f_inner)

    @functools.wraps(f)
    def _f_outer(*args, **kwargs):
      args = tuple(_tree_unshare(a) for a in args)
      kwargs = {k: _tree_unshare(v) for k,v in kwargs.items()}
      return _tree_share(_dec_f(*args, **kwargs))

    return _f_outer

  return _dec

Here, you don't need to specifically label anything as shared. No special construction or labelling of the PyTree. No extra programmatic syntax to manipulate the shared PyTrees. You can specify sharing by simply sharing, and if you don't want something shared, make sure it's a copy and not the same object. This seems to fit equinox's ethos perfectly. So maybe this could just be built into eqx.filter_jit.

The above is limited to sharing at leaves (which is where all the dynamic arrays will be anyway), but it could relatively easily be modified to allow shared internal subtrees where stubs encode paths to where the original subtree lives.
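As a rough sketch of what path-based stubs might look like, the snippet below records, for each duplicated leaf, the path of its first occurrence. It still only covers leaves; detecting shared internal subtrees would need a custom traversal that is not shown here. The helper name `find_shared_paths` is hypothetical, and it relies on `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

def find_shared_paths(tree):
    """Map each duplicate leaf's path to the path of its first occurrence."""
    path_leaves, _ = jtu.tree_flatten_with_path(tree)
    first_seen = {}  # id(leaf) -> path of the first occurrence
    stubs = {}       # path of a duplicate -> path of the original
    for path, leaf in path_leaves:
        if id(leaf) in first_seen:
            stubs[jtu.keystr(path)] = jtu.keystr(first_seen[id(leaf)])
        else:
            first_seen[id(leaf)] = path
    return stubs

w = jnp.zeros((2,))
tree = {"a": {"w": w}, "b": {"w": w}, "c": jnp.ones((2,))}

# One entry: the leaf under "b" points back to the leaf under "a".
stubs = find_shared_paths(tree)
```

Storing paths rather than flat indices keeps the stub meaningful even if the tree is inspected without flattening it first, at the cost of a slightly larger placeholder.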

Another limitation is that two PyTrees with two different sharing structures (e.g., one has shared leaves, and the other doesn't) will result in a re-JIT, as the input signature would be tied to its sharing structure. But overall, this seems like desirable behavior.
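The re-JIT behavior can be illustrated with a small sketch of a "sharing signature": for each leaf, the index of the first leaf that is the same object. This assumes the placeholders end up as static, non-array leaves under eqx.filter_jit, so two inputs with different signatures would present different static structure. The helper `sharing_signature` is hypothetical and only mirrors the indexing used by the placeholders above.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

def sharing_signature(tree):
    """For each leaf, the flat index of the first leaf that is the same object."""
    first_index = {}
    signature = []
    for i, leaf in enumerate(jtu.tree_leaves(tree)):
        # setdefault returns the earlier index if this object was seen before.
        signature.append(first_index.setdefault(id(leaf), i))
    return tuple(signature)

w = jnp.ones((2,))
sig_shared = sharing_signature((w, w))              # (0, 0)
sig_unshared = sharing_signature((w, jnp.copy(w)))  # (0, 1)
```

Both inputs have identical shapes and dtypes, yet their signatures differ, which is why they would be compiled separately.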

Maybe there are other limitations that prevent this from handling everything one would want sharing to handle?

Thoughts?

I could build this into a PR for filter_jit, but I don't want to take the time if there's no interest.

bowlingmh, Nov 24 '23 21:11