
Create `vectorized`, `value_and_grad` and `shape` versions of `JAXOp`

Open · ricardoV94 opened this issue 2 months ago

Description

When a `JAXOp` remains in the final graph compiled for a non-JAX backend, we may want to rewrite it for efficiency. For example, we could rewrite `Blockwise(JAXOp)` into a single `JAXOp` whose inner function is vectorized, rather than calling the core function once per batch entry.
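A minimal sketch of the vectorization idea, assuming the `JAXOp` wraps a hypothetical core function `core_fn` that maps a vector to a scalar. The `looped` helper (also hypothetical) roughly mirrors calling the core function once per batch entry, while `vectorized` is the kind of inner function a fused `JAXOp` could use: `jax.vmap` over the leading batch axis, compiled once with `jax.jit`. This is not PyTensor's implementation, just the JAX side of the proposed rewrite.

```python
import jax
import jax.numpy as jnp


# Hypothetical core function wrapped by a JAXOp: vector of shape (n,) -> scalar.
def core_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)


# Roughly what Blockwise(JAXOp) amounts to: one core call per batch entry.
def looped(xs):
    return jnp.stack([core_fn(x) for x in xs])


# Proposed rewrite target: a single inner function vectorized over the
# leading batch axis with jax.vmap and jitted as one computation.
vectorized = jax.jit(jax.vmap(core_fn))

xs = jnp.arange(12.0).reshape(4, 3)
print(vectorized(xs).shape)  # (4,)
```

Both versions produce the same values; the vectorized one lets XLA see the whole batch at once instead of dispatching one small computation per row.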

If the graph contains both the `JAXOp` output and its gradient, we could merge them into a single Op that uses `value_and_grad` under the hood, so the forward computation is not repeated for the gradient.
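A minimal sketch of the fused value-and-gradient idea, assuming a hypothetical scalar-output core function `core_fn`. `value_op` and `grad_op` stand in for the two separate Ops (illustrative names), and `value_and_grad_op` is what a merged Op could call internally via `jax.value_and_grad`, returning both results from one compiled function.

```python
import jax
import jax.numpy as jnp


# Hypothetical scalar-valued core function wrapped by a JAXOp.
def core_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)


# Two separate Ops: one for the value, one for the gradient.
# jax.grad traces the forward pass again internally.
value_op = jax.jit(core_fn)
grad_op = jax.jit(jax.grad(core_fn))

# Merged Op: one compiled call returns (value, gradient), sharing the
# forward computation between them.
value_and_grad_op = jax.jit(jax.value_and_grad(core_fn))

x = jnp.array([0.0, 1.0, 2.0])
value, grad = value_and_grad_op(x)
```

`jax.value_and_grad` requires a scalar output. For a non-scalar `JAXOp` output, the analogous tool is `jax.vjp`, which returns the primal output together with a pullback function that maps output cotangents to input gradients.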

Similarly, if only the output shape is needed, we could rewrite into an Op whose inner function computes just the shape. This last rewrite only pays off if the original Op does not otherwise remain in the graph.
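A minimal sketch of a shape-only inner function, assuming the input shapes and dtypes are known when the shape Op runs. `core_fn` and `shape_fn` are hypothetical names; `jax.eval_shape` traces the function abstractly and returns `jax.ShapeDtypeStruct` results without performing the actual computation.

```python
import jax
import jax.numpy as jnp


# Hypothetical core function wrapped by a JAXOp.
def core_fn(x, y):
    return jnp.tanh(jnp.outer(x, y))


# Shape-only version: abstract evaluation, no arrays are materialized.
def shape_fn(x, y):
    out = jax.eval_shape(core_fn, x, y)
    return out.shape


# Inputs can be described by shape and dtype alone.
x_spec = jax.ShapeDtypeStruct((5,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((7,), jnp.float32)
print(shape_fn(x_spec, y_spec))  # (5, 7)
```

Because only the shapes are traced, a replacement Op built this way costs a trace rather than a full evaluation, which is why it only helps when the original Op itself can be dropped from the graph.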

ricardoV94 · Oct 09 '25 10:10