
How well does tree-math support computation on multiple devices?

connection-on-fiber-bundles opened this issue 2 years ago • 1 comment

I'm wondering how well tree-math supports computation across multiple devices.

Say we have a pytree of tensors with different shapes and want to apply some operations to each of them with tree-math. Can we distribute those tasks across multiple devices (GPUs, for instance)?

This should not be a problem. tree-math is entirely agnostic to JAX's multi-device APIs: it's just syntactic sugar for jax.tree_util.tree_map.

shoyer, Apr 25 '22 18:04