tree-math
`tm.unwrap`: `ValueError` when `out_vectors` is a list instead of a tuple [documentation]
import jax.numpy as jnp
import tree_math as tm
def f(x, y):
    """Return both arguments unchanged, as a 2-tuple ``(x, y)``.

    This is the function wrapped by ``tm.unwrap`` in the reproduction below.
    Because it returns two outputs, each entry of ``out_vectors`` selects
    whether the corresponding output is returned as a ``tm.Vector``.
    """
    # The body must be indented; in the extracted snippet it was at column 0,
    # which is an IndentationError.
    return x, y
# Both inputs are the very same scalar wrapped in a tree_math Vector.
x = tm.Vector(jnp.array(0.))
y = x

# With a tuple of flags, the call succeeds: the first output comes back as a
# tm.Vector and the second as a plain array.
unwrap_with_tuple = tm.unwrap(f, out_vectors=(True, False))
unwrap_with_tuple(x, y)
# (tree_math.Vector(DeviceArray(0., dtype=float32, weak_type=True)), DeviceArray(0., dtype=float32, weak_type=True))

# The same flags passed as a list fail. The error message suggests the list's
# container type is being matched against the tuple that f returns.
unwrap_with_list = tm.unwrap(f, out_vectors=[True, False])
unwrap_with_list(x, y)
# ValueError: Expected list, got (DeviceArray(0., dtype=float32, weak_type=True), DeviceArray(0., dtype=float32, weak_type=True)).