
Operations should allow shape broadcasting

Open GeoffNN opened this issue 3 years ago • 5 comments

It seems like the current implementation doesn't allow broadcasting between arguments. Here's an example that normalizes each leaf:

import tree_math as tm
import jax
import jax.numpy as jnp

a = jnp.ones(10)
b = jnp.ones(5)

v = tm.Vector({'a': a, 'b': b})
v / jax.tree_map(jnp.linalg.norm, v)

raises the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-19-c2a8ad9c2f8f> in <module>()
----> 1 v / jax.tree_map(jnp.linalg.norm, v)

2 frames
/usr/local/lib/python3.7/dist-packages/tree_math/_src/vector.py in wrapper(self, other)
     72   """Implement a forward binary method, e.g., __add__."""
     73   def wrapper(self, other):
---> 74     return broadcasting_map(func, self, other)
     75   wrapper.__name__ = f"__{name}__"
     76   return wrapper

/usr/local/lib/python3.7/dist-packages/tree_math/_src/vector.py in broadcasting_map(func, *args)
     65   if not vector_args:
     66     return func2()  # result is a scalar
---> 67   _flatten_together(*[arg.tree for arg in vector_args])  # check shapes
     68   return tree_util.tree_map(func2, *vector_args)
     69 

/usr/local/lib/python3.7/dist-packages/tree_math/_src/vector.py in _flatten_together(*args)
     37   if not all(shapes == all_shapes[0] for shapes in all_shapes[1:]):
     38     shapes_str = " vs ".join(map(str, all_shapes))
---> 39     raise ValueError(f"tree leaves have different array shapes: {shapes_str}")
     40 
     41   return all_values, all_treedefs[0]

ValueError: tree leaves have different array shapes: [(10,), (5,)] vs [(), ()]
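The failure comes from a check that all Vector operands have leaves of identical shape. The snippet below is a rough plain-JAX imitation of that check, not tree_math's actual code. It reproduces the two shape lists from the message: the per-leaf norms are 0-d arrays, so their shapes cannot match `(10,)` and `(5,)`.

```python
import jax
import jax.numpy as jnp

# The two operands from the example: v's leaves and the per-leaf norms.
tree = {"a": jnp.ones(10), "b": jnp.ones(5)}
norms = jax.tree_util.tree_map(jnp.linalg.norm, tree)

def leaf_shapes(t):
    # Shapes of the leaves in pytree-flattening order (sorted dict keys).
    return [x.shape for x in jax.tree_util.tree_leaves(t)]

shapes = [leaf_shapes(tree), leaf_shapes(norms)]
print(" vs ".join(map(str, shapes)))  # [(10,), (5,)] vs [(), ()]
```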

GeoffNN, Jan 03 '22 04:01

The intention of the code makes sense. However, under the current semantics of Vector the operation is odd: v has shape (15,), but after taking the norm of each leaf the result has shape (2,), and division between vectors of different sizes makes no sense.
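To make the size mismatch concrete, here is a small sketch using plain JAX (not tree_math). It flattens both operands with `jax.flatten_util.ravel_pytree`, which matches the flat-vector view described above. It then shows that the leaf-wise normalization the original example wanted can be written directly with `jax.tree_util.tree_map` over the underlying dict. There, each leaf is divided by its own scalar through ordinary per-leaf broadcasting.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree = {"a": jnp.ones(10), "b": jnp.ones(5)}
norms = jax.tree_util.tree_map(jnp.linalg.norm, tree)

# Viewed as flat vectors (the Vector data model), the operands differ in length.
flat_tree, _ = ravel_pytree(tree)
flat_norms, _ = ravel_pytree(norms)
print(flat_tree.shape, flat_norms.shape)  # (15,) (2,)

# Leaf-wise normalization without Vector: pair each leaf with its own norm.
normalized = jax.tree_util.tree_map(lambda x, n: x / n, tree, norms)
```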

Maybe we need a looser abstraction (e.g. Numeric) that just lets you do leaf-wise math operations but doesn't have array-like semantics.
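Below is a minimal sketch of what such a looser type could look like. It assumes the type simply wraps a pytree and applies each binary operation leaf by leaf, with NumPy-style broadcasting inside each leaf and no flat-vector semantics. `Numeric` and `norm` here are hypothetical illustrations, not part of tree_math.

```python
import operator

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Numeric:
    """Hypothetical leaf-wise wrapper: arithmetic is applied leaf by leaf."""

    def __init__(self, tree):
        self.tree = tree

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def _binary(self, other, op):
        if isinstance(other, Numeric):
            # Tree structures must match; leaf shapes only need to broadcast.
            return Numeric(jax.tree_util.tree_map(op, self.tree, other.tree))
        # A scalar or array operand is broadcast against every leaf.
        return Numeric(jax.tree_util.tree_map(lambda x: op(x, other), self.tree))

    def __add__(self, other):
        return self._binary(other, operator.add)

    def __sub__(self, other):
        return self._binary(other, operator.sub)

    def __mul__(self, other):
        return self._binary(other, operator.mul)

    def __truediv__(self, other):
        return self._binary(other, operator.truediv)


def norm(x):
    """Hypothetical leaf-wise norm: one scalar per leaf, kept as a Numeric."""
    return Numeric(jax.tree_util.tree_map(jnp.linalg.norm, x.tree))


tree = Numeric({"a": jnp.ones(10), "b": jnp.ones(5)})
normalized = tree / norm(tree)  # each leaf divided by its own norm
```

Registering the class as a pytree node means a `Numeric` can still pass through `jax.jit` or `jax.grad`. Reductions such as a global dot product are deliberately absent.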

cgarciae, Jan 03 '22 16:01

I think I see. I was somehow expecting Vector to allow leaf-wise math. The abstraction you suggest could be nice for avoiding tree_map calls in general. For the example above, being able to write normalized_tree = tree / tm.norm(tree), with tree an instance of Numeric, would be ideal. One use case is implementing power iteration with a pytree input.
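As an illustration of that use case, here is a sketch of power iteration on a pytree using plain `jax.tree_util`. It assumes the operator acts on each leaf independently (a block-diagonal operator, e.g. one matrix per layer), so every leaf is renormalized by its own norm, which is exactly the `tree / norm(tree)` step. `power_iteration`, `mats`, and `matvec` are hypothetical names for this example. If the operator mixes leaves, a single global norm over the whole tree would be the right normalization instead.

```python
import jax
import jax.numpy as jnp


def power_iteration(matvec, x0, num_iters=100):
    """Leaf-wise power iteration (hypothetical helper).

    Assumes `matvec` maps a pytree to a pytree of the same structure and acts
    on each leaf independently, so each leaf is normalized by its own norm.
    """
    def normalize(tree):
        return jax.tree_util.tree_map(lambda x: x / jnp.linalg.norm(x), tree)

    x = normalize(x0)
    for _ in range(num_iters):
        x = normalize(matvec(x))
    # Per-leaf Rayleigh quotient x^T A x, with each leaf of x at unit norm.
    eigenvalues = jax.tree_util.tree_map(jnp.vdot, x, matvec(x))
    return eigenvalues, x


# Hypothetical block-diagonal operator: one symmetric matrix per leaf.
mats = {"a": jnp.diag(jnp.arange(1.0, 11.0)), "b": jnp.diag(jnp.arange(1.0, 6.0))}


def matvec(t):
    return jax.tree_util.tree_map(jnp.matmul, mats, t)


x0 = {"a": jnp.ones(10), "b": jnp.ones(5)}
eigs, _ = power_iteration(matvec, x0)  # eigs["a"] ≈ 10.0, eigs["b"] ≈ 5.0
```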

GeoffNN, Jan 03 '22 21:01

Leaf-wise math sounds potentially useful, but, as @cgarciae notes, it's definitely a different data model. I would love to see a more fully fleshed-out use case for this functionality.

shoyer, Jan 04 '22 23:01

Since many of Vector's operations are also leaf-wise, Numeric / NumericMixin could serve as a base class for Vector and friends. One big difference from Vector would be the dot method: it would either also have to be leaf-wise and return a pytree, or simply not be included as an operation.
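The two candidate meanings of `dot` can be contrasted with a short plain-JAX sketch. Here `leafwise_dot` and `global_dot` are hypothetical helpers, not tree_math functions. The leaf-wise version returns a pytree with one scalar per leaf. The flat-vector version, the single scalar that Vector's semantics imply, is just the sum of those per-leaf results.

```python
import jax
import jax.numpy as jnp


def leafwise_dot(x, y):
    """Hypothetical leaf-wise dot: one scalar per leaf, returned as a pytree."""
    return jax.tree_util.tree_map(jnp.vdot, x, y)


def global_dot(x, y):
    """Dot product of the flattened pytrees: the sum of the leaf-wise dots."""
    return sum(jax.tree_util.tree_leaves(leafwise_dot(x, y)))


x = {"a": jnp.ones(10), "b": 2 * jnp.ones(5)}
per_leaf = leafwise_dot(x, x)  # leaves: "a" -> 10.0, "b" -> 20.0
total = global_dot(x, x)       # 30.0
```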

cgarciae, Jan 05 '22 21:01

My mental model for objects like this is that they are something closer to ragged arrays, e.g., a 2D matrix where the first dimension corresponds to the number of pytree leaves, and the second dimension corresponds to the (variable) size of each leaf.
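One way to make the ragged-array picture concrete is sketched below with plain JAX. It assumes row i is the i-th leaf in flattening order. The tree is raveled into one flat array, and each element is tagged with the index of the leaf it came from. Leaf-wise reductions then become segment reductions along the ragged second axis. The approach uses `jax.ops.segment_sum` and is an illustration of the data model, not a proposed tree_math API.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree = {"a": jnp.ones(10), "b": 2 * jnp.ones(5)}
leaves = jax.tree_util.tree_leaves(tree)

# Ragged view: row i of the "matrix" is leaf i, with its own length.
flat, unravel = ravel_pytree(tree)
row_ids = jnp.concatenate([jnp.full(leaf.size, i) for i, leaf in enumerate(leaves)])

# A leaf-wise reduction is a reduction along the ragged second axis.
sq_sums = jax.ops.segment_sum(flat**2, row_ids, num_segments=len(leaves))
row_norms = jnp.sqrt(sq_sums)  # one norm per leaf: [sqrt(10), sqrt(20)]

# Broadcasting each row's value back over that row gives leaf-wise normalization.
normalized = unravel(flat / row_norms[row_ids])
```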

shoyer, Jan 05 '22 21:01