tree-math
Operations should allow shape broadcasting
It seems like the current implementation doesn't allow broadcasting arguments. Here's an example of normalizing leaves.
import tree_math as tm
import jax
import jax.numpy as jnp
a = jnp.ones(10)
b = jnp.ones(5)
v = tm.Vector({'a': a, 'b': b})
v / jax.tree_map(jnp.linalg.norm, v)
raises the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-19-c2a8ad9c2f8f> in <module>()
----> 1 v / jax.tree_map(jnp.linalg.norm, v)
/usr/local/lib/python3.7/dist-packages/tree_math/_src/vector.py in wrapper(self, other)
72 """Implement a forward binary method, e.g., __add__."""
73 def wrapper(self, other):
---> 74 return broadcasting_map(func, self, other)
75 wrapper.__name__ = f"__{name}__"
76 return wrapper
/usr/local/lib/python3.7/dist-packages/tree_math/_src/vector.py in broadcasting_map(func, *args)
65 if not vector_args:
66 return func2() # result is a scalar
---> 67 _flatten_together(*[arg.tree for arg in vector_args]) # check shapes
68 return tree_util.tree_map(func2, *vector_args)
69
/usr/local/lib/python3.7/dist-packages/tree_math/_src/vector.py in _flatten_together(*args)
37 if not all(shapes == all_shapes[0] for shapes in all_shapes[1:]):
38 shapes_str = " vs ".join(map(str, all_shapes))
---> 39 raise ValueError(f"tree leaves have different array shapes: {shapes_str}")
40
41 return all_values, all_treedefs[0]
ValueError: tree leaves have different array shapes: [(10,), (5,)] vs [(), ()]
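For reference, the leaf-wise normalization itself can be written today by mapping over the Vector directly; this is only a workaround sketch, not the API being proposed:

import jax
import jax.numpy as jnp
import tree_math as tm

a = jnp.ones(10)
b = jnp.ones(5)
v = tm.Vector({'a': a, 'b': b})

# Normalize each leaf by its own norm in a single tree_map, sidestepping the
# shape check that Vector's division performs.
normalized = jax.tree_map(lambda x: x / jnp.linalg.norm(x), v)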
The intention of the code makes sense; however, under the current semantics of Vector the operation is ill-defined: v has shape (15,), but after taking the norm per leaf the result has shape (2,), and division between vectors of those shapes would make no sense.
Maybe we need a looser abstraction (e.g. Numeric) that just lets you do leaf-wise math operations but doesn't have array-like semantics.
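For concreteness, a rough sketch of what such a wrapper could look like (the name Numeric and everything about its interface here are hypothetical, not existing tree_math API):

import operator
import jax
import jax.numpy as jnp


class Numeric:
    """Hypothetical leaf-wise wrapper: arithmetic maps over leaves, with no array-like semantics."""

    def __init__(self, tree):
        self.tree = tree

    def _leafwise(self, func, other):
        if isinstance(other, Numeric):
            # Apply the operation leaf by leaf; ordinary NumPy broadcasting
            # within each leaf handles scalar-vs-array cases like norms.
            return Numeric(jax.tree_map(func, self.tree, other.tree))
        # Plain scalars broadcast against every leaf.
        return Numeric(jax.tree_map(lambda x: func(x, other), self.tree))

    def __add__(self, other):
        return self._leafwise(operator.add, other)

    def __mul__(self, other):
        return self._leafwise(operator.mul, other)

    def __truediv__(self, other):
        return self._leafwise(operator.truediv, other)


# The example from above, spelled leaf-wise:
n = Numeric({'a': jnp.ones(10), 'b': jnp.ones(5)})
normalized = n / Numeric(jax.tree_map(jnp.linalg.norm, n.tree))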
I think I see. I was somehow expecting Vector to allow leaf-wise math. The abstraction you suggest could be nice to avoid tree_map calls in general. For the above example, having something like normalized_tree = tree / tm.norm(tree), with tree an instance of Numeric, may be ideal. One example use case is implementing power iteration with a pytree input.
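As a sketch of that use case (the function and argument names here are placeholders of mine): power iteration only needs the global norm for its normalization step, so it can already be written with the existing Vector, and a leaf-wise abstraction would mainly matter for per-leaf operations like the normalization example above.

import jax.numpy as jnp
import tree_math as tm


def power_iteration(matvec, v0, num_steps=50):
    # matvec maps a pytree to a pytree with the same structure (a linear operator);
    # v0 is the initial guess as a pytree.
    v = tm.Vector(v0)
    for _ in range(num_steps):
        w = tm.Vector(matvec(v.tree))
        v = w / jnp.sqrt(w @ w)  # normalize by the global 2-norm of the whole pytree
    eigenvalue = v @ tm.Vector(matvec(v.tree))
    return eigenvalue, v.tree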
Leaf-wise math sounds potentially useful, but as @cgarciae notes, it's definitely a different data model. I would love to see a more fully fleshed-out use case for this functionality.
Since a lot of the operations from Vector are also leaf-wise, there is the possibility that Numeric / NumericMixin could serve as a base class for Vector and friends. One of the big differences from Vector would be the dot method: it would either have to also be leaf-wise and return a pytree, or just not be included as an operation.
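If it were kept, a leaf-wise dot might be nothing more than this (illustrative sketch only):

import jax
import jax.numpy as jnp


def leafwise_dot(x_tree, y_tree):
    # Returns a pytree of per-leaf dot products, in contrast to Vector's dot,
    # which sums contributions across all leaves into a single scalar.
    return jax.tree_map(lambda x, y: jnp.vdot(x, y), x_tree, y_tree)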
My mental model for objects like this is that they are something closer to ragged arrays, e.g., a 2D matrix where the first dimension corresponds to the number of pytree leaves, and the second dimension corresponds to the (variable) size of each leaf.