einx
segment_sum
Is there a way to compute the equivalent of segment_sum (and the segment versions of product, softmax, etc.) in einx?
For reference, see jax.ops.segment_sum in the JAX documentation: https://docs.jax.dev/en/latest/_autosummary/jax.ops.segment_sum.html
I found a way to do this in einx using "add_at".
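The idea behind the "add_at" approach is that a segment sum is a scatter-add: each element is added into the output slot named by its segment id. Below is a minimal NumPy sketch of that idea. It uses NumPy's np.add.at as a stand-in for einx's "add_at", whose exact call signature is not assumed here. In JAX, the analogous operation is out.at[segment_ids].add(data). The helper name segment_sum is hypothetical.

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
    """Hypothetical helper: sum all entries of `data` that share a segment id."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at is unbuffered, so repeated indices accumulate
    # instead of overwriting each other (a plain out[ids] += data would not).
    np.add.at(out, segment_ids, data)
    return out

data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = np.array([0, 0, 1, 2, 2])
print(segment_sum(data, segment_ids, num_segments=3))  # [3. 3. 9.]
```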
However, to implement logmax correctly, we also need an equivalent of "max_at" (or "min_at"). Could this be added?
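A per-segment max is needed because a numerically stable segment softmax has to subtract each segment's maximum before exponentiating. Otherwise, np.exp overflows for large logits and the result becomes nan. The sketch below assumes floating-point inputs. It uses NumPy's np.maximum.at as a stand-in for the requested "max_at"; in JAX the analogues are out.at[ids].max(x) or jax.ops.segment_max. No particular einx API is assumed, and segment_max and segment_softmax are hypothetical helper names.

```python
import numpy as np

def segment_max(data, segment_ids, num_segments):
    """Hypothetical helper: per-segment maximum via an unbuffered scatter-max."""
    # Start from -inf so any real value wins; empty segments stay at -inf.
    out = np.full((num_segments,) + data.shape[1:], -np.inf, dtype=data.dtype)
    np.maximum.at(out, segment_ids, data)
    return out

def segment_softmax(logits, segment_ids, num_segments):
    """Hypothetical helper: softmax computed independently within each segment."""
    # Step 1 ("max_at"): shift by the segment max so exp() cannot overflow.
    seg_max = segment_max(logits, segment_ids, num_segments)
    shifted = np.exp(logits - seg_max[segment_ids])
    # Step 2 ("add_at"): per-segment normalizer.
    denom = np.zeros((num_segments,) + logits.shape[1:], dtype=logits.dtype)
    np.add.at(denom, segment_ids, shifted)
    return shifted / denom[segment_ids]

logits = np.array([1000.0, 1001.0, 3.0])
segment_ids = np.array([0, 0, 1])
# Without the max shift, exp(1000.0) would overflow to inf and give nan.
print(segment_softmax(logits, segment_ids, num_segments=2))  # ≈ [0.269 0.731 1.0]
```

The same two steps also give a stable segment logsumexp: np.log(denom) + seg_max.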