tree-math
How well does tree-math support computation on multiple devices?
I'm wondering how well tree-math supports computation on multiple devices. Let's say we have a pytree of tensors with different dimensions and want to apply some operations to each of them with tree-math. Can we distribute those tasks across multiple devices (GPUs, for instance)?
This should not be a problem. tree-math is entirely agnostic to JAX's multi-device APIs: it's just syntactic sugar for jax.tree_util.tree_map, so each operation is applied leaf by leaf and each leaf is computed wherever JAX places it.
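To make the answer concrete, here is a minimal sketch that uses plain jax.tree_util.tree_map, which the answer says tree-math's operations reduce to, rather than tree-math itself. It assumes an illustrative pytree and a hypothetical helper, spread_leaves, neither of which comes from tree-math. Part 1 places each leaf on a device round-robin and relies on JAX's "computation follows data" rule for committed arrays in eager mode. Part 2 shards every leaf along its leading axis with jax.sharding and runs the same leaf-wise map under jax.jit. It also runs on a single CPU device, where both parts simply use that one device.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


# Part 1 (hypothetical helper): commit each leaf to a device, round-robin.
def spread_leaves(tree):
    devices = itertools.cycle(jax.devices())
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    placed = [jax.device_put(leaf, next(devices)) for leaf in leaves]
    return jax.tree_util.tree_unflatten(treedef, placed)


# Illustrative pytree with leaves of different dimensions (an assumption).
tree = {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,)), "s": jnp.array(2.0)}
placed = spread_leaves(tree)

# Eager, leaf-wise op: each leaf's computation runs on the device it lives on.
result = jax.tree_util.tree_map(lambda x: 2.0 * x + 1.0, placed)
for name, leaf in result.items():
    print(name, leaf.devices())

# Part 2: shard every leaf along its leading axis over a 1-D device mesh.
devices = np.array(jax.devices())
n = len(devices)
mesh = Mesh(devices, axis_names=("d",))
sharding = NamedSharding(mesh, P("d"))  # split axis 0 over "d", replicate the rest

# Leading dimensions are multiples of n so they divide evenly across devices.
tree2 = {
    "a": jnp.ones((4 * n,)),
    "b": jnp.arange(8 * n * 3, dtype=jnp.float32).reshape(8 * n, 3),
}
sharded = jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), tree2)


@jax.jit
def step(t):
    # The same leaf-wise map; the compiler handles the sharded layout.
    return jax.tree_util.tree_map(lambda x: 2.0 * x + 1.0, t)


out = step(sharded)
print(out["b"].sharding)
```

The design point is that neither part needs anything from tree-math: device placement and sharding are decided by how the leaves are created, and the leaf-wise map leaves them alone. One caveat: passing leaves committed to *different* devices (as in Part 1) into a single jax.jit call is generally rejected by JAX, which is why Part 1 stays eager and Part 2 gives every leaf the same sharding.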