tree-math

`tm.unwrap`: Error when `out_vectors` is a list [documentation]

Open • deasmhumhna opened this issue 2 years ago • 1 comment

```python
import jax.numpy as jnp
import tree_math as tm

def f(x, y):
  return x, y  # f returns a tuple

x = y = tm.Vector(jnp.array(0.))

# Works: out_vectors is a tuple, the same container type that f returns.
tm.unwrap(f, out_vectors=(True, False))(x, y)
# (tree_math.Vector(DeviceArray(0., dtype=float32, weak_type=True)), DeviceArray(0., dtype=float32, weak_type=True))

# Fails: the same flags passed as a list.
tm.unwrap(f, out_vectors=[True, False])(x, y)
# ValueError: Expected list, got (DeviceArray(0., dtype=float32, weak_type=True), DeviceArray(0., dtype=float32, weak_type=True)).
```
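The message matches the error `jax.tree_util` raises when a prefix tree and the full tree use different container types (a `list` versus a `tuple`). The sketch below assumes, without checking tree-math's source, that `unwrap` pairs `out_vectors` with the function's output as a pytree prefix; `apply_flags` is a hypothetical stand-in for that step, and it reproduces the same kind of failure with JAX alone:

```python
import jax
import jax.numpy as jnp

# Stand-in for the tuple returned by f in the report.
out = (jnp.array(0.), jnp.array(0.))

def apply_flags(flags, out):
  """Hypothetical stand-in for how unwrap might pair out_vectors with the output.

  tree_map treats `flags` as a prefix of `out`, so each container in
  `flags` must have the same type as the matching container in `out`.
  """
  return jax.tree_util.tree_map(lambda flag, leaf: (flag, leaf), flags, out)

# Tuple flags vs. tuple output: the structures line up.
print(apply_flags((True, False), out))

# List flags vs. tuple output: jax.tree_util raises ValueError.
try:
  apply_flags([True, False], out)
except ValueError as err:
  print("ValueError:", err)
```

Under that assumption, the workaround is to give `out_vectors` the same container type as the return value, e.g. `out_vectors=(True, False)` for a function that returns a tuple.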

deasmhumhna, Nov 10 '22 21:11