tree-math
Mathematical operations for JAX pytrees
How should I go about importing `struct`? Thanks! My attempt below fails:

```python
import tree_math

@tree_math.struct
class Point:
  x: float
  y: float
```

AttributeError: module 'tree_math' has no attribute...
Wondering how well tree-math supports computation on multiple devices? Let's say we have a pytree of tensors of different dimensions and want to perform some operations on each of them...
It would be nice to have an easy way to define dataclasses that are also tree-math vectors. We could borrow the syntax of flax.struct here: https://flax.readthedocs.io/en/latest/flax.struct.html Example usage: ``` from...
I'd like to add a `Matrix` class to complement `Vector`. A key design question is what this needs to support. In particular: do we need to support multiple axes that...
It seems like the current implementation doesn't allow broadcasting arguments. Here's an example for normalizing leaves:

```python
import tree_math as tm
import jax
import jax.numpy as jnp

a = jnp.ones(10)...
```
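One possible workaround, sketched under the assumption that the goal is to scale each leaf by its own norm: do the per-leaf reduction and the division inside a single `jax.tree_util.tree_map` over the underlying pytree, so no broadcasting between vectors of differently shaped leaves is needed. `normalize_leaves` is a hypothetical helper, not part of tree-math.

```python
import jax
import jax.numpy as jnp


def normalize_leaves(tree):
    """Divide every leaf of a pytree by its own L2 (Frobenius) norm.

    Hypothetical helper: the norm of each leaf is computed and applied
    inside one tree_map call, so leaves of different shapes never have
    to broadcast against each other.
    """
    return jax.tree_util.tree_map(lambda x: x / jnp.linalg.norm(x), tree)


tree = {"a": jnp.ones(10), "b": jnp.full((2, 3), 2.0)}
normalized = normalize_leaves(tree)

# Norm of each normalized leaf; every entry should be (numerically) 1.
norms = jax.tree_util.tree_map(jnp.linalg.norm, normalized)
```

If a `tm.Vector` is needed afterwards, the resulting pytree can be wrapped again with `tm.Vector(normalized)`.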
These are rather undescriptive names, and I guess most people (including myself!) will have to guess & check to keep them straight. :) Some other possibilities, in rough order of...
The `save-state` and `set-output` commands used in GitHub Actions are deprecated, and [GitHub recommends using environment files](https://github.blog/changelog/2023-07-24-github-actions-update-on-save-state-and-set-output-commands/). This PR replaces uses of `::set-output` with writes to `"$GITHUB_OUTPUT"`. Instructions for envvar usage from...
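In shell steps the replacement is `echo "name=value" >> "$GITHUB_OUTPUT"`. For a step that runs Python instead, a minimal sketch of the same idea is to append `name=value` lines to the file whose path GitHub Actions puts in the `GITHUB_OUTPUT` environment variable. `set_output` is a hypothetical helper name, and it only handles single-line values.

```python
import os


def set_output(name: str, value: str) -> None:
    """Append a step output to the environment file named by $GITHUB_OUTPUT.

    Hypothetical helper replacing the deprecated
    `echo "::set-output name=NAME::VALUE"` command. Multi-line values
    need GitHub's `NAME<<DELIMITER` syntax, which is not handled here.
    """
    if "\n" in value:
        raise ValueError("multi-line values need the NAME<<DELIMITER syntax")
    with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf-8") as fh:
        fh.write(f"{name}={value}\n")
```

Later steps then read the value as `steps.<step_id>.outputs.<name>`, exactly as with the old command.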
It would be nice to have some fields be pytree nodes and others not, which would make this a full replacement for `flax.struct`.
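A minimal sketch of what mixed fields could look like, assuming plain JAX only: a decorator builds a frozen dataclass and registers it with `jax.tree_util.register_pytree_node`, keeping fields marked `metadata={"static": True}` in the treedef's auxiliary data instead of as leaves. This loosely mirrors `flax.struct.field(pytree_node=False)`; the decorator name `pytree_dataclass` and the `"static"` metadata key are hypothetical and not part of tree-math.

```python
import dataclasses

import jax
import jax.numpy as jnp


def pytree_dataclass(cls):
    """Make `cls` a frozen dataclass registered as a JAX pytree.

    Fields with metadata={"static": True} become auxiliary (treedef) data;
    all other fields become pytree leaves. Hypothetical sketch.
    """
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    data_names = [f.name for f in fields if not f.metadata.get("static", False)]
    static_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in data_names)
        aux = tuple(getattr(obj, n) for n in static_names)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(data_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class Point:
    x: jnp.ndarray
    y: jnp.ndarray
    # Static field: carried through transformations, never mapped over.
    name: str = dataclasses.field(default="p", metadata={"static": True})


p = Point(jnp.array(1.0), jnp.array(2.0), name="example")
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)  # only x and y change
total = jax.jit(lambda q: q.x + q.y)(p)                # name is static under jit
```

Because `tm.Vector` wraps an arbitrary pytree, an instance such as `p` could presumably be wrapped as `tm.Vector(p)` to get vector arithmetic over just the non-static fields.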
```python
import jax.numpy as jnp
import tree_math as tm

def f(x, y):
  return x, y

x = y = tm.Vector(jnp.array(0.))
tm.unwrap(f, out_vectors=(True, False))(x, y)
# (tree_math.Vector(DeviceArray(0., dtype=float32, weak_type=True)),...
```