[BUG] Using jax_callable() with GraphMode.WARP results in too many graph captures
Bug Description
When JAX doesn't reuse the same memory for the callable's inputs and outputs from one call to the next, a new graph must be captured for that call, which severely degrades performance.
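To illustrate the failure mode, here is a minimal pure-Python sketch. It is not Warp's actual implementation: the `AddressKeyedGraphCache` class, its `launch` method, and the fake addresses are all hypothetical, and it assumes that a captured graph is only reusable when the input and output buffer addresses match those seen at capture time.

```python
# Hypothetical illustration only: this is NOT Warp's actual implementation.
# Assumption: a captured graph can be replayed only for the exact buffer
# addresses it was captured with.
from typing import Callable, Dict, Tuple


class AddressKeyedGraphCache:
    """Toy cache that 'captures' one graph per unique tuple of buffer addresses."""

    def __init__(self) -> None:
        self.graphs: Dict[Tuple[int, ...], Callable[[], None]] = {}
        self.capture_count = 0

    def launch(self, *buffer_addresses: int) -> None:
        key = tuple(buffer_addresses)
        graph = self.graphs.get(key)
        if graph is None:
            # Cache miss: stands in for an expensive graph capture.
            self.capture_count += 1
            graph = lambda: None
            self.graphs[key] = graph
        # Cache hit or freshly captured: stands in for a cheap graph replay.
        graph()


# Case 1: the caller reuses the same input/output buffers on every call.
stable = AddressKeyedGraphCache()
for _ in range(100):
    stable.launch(0x1000, 0x2000)

# Case 2: the caller hands over a fresh output buffer on every call.
churning = AddressKeyedGraphCache()
for step in range(100):
    churning.launch(0x1000, 0x2000 + 0x100 * step)

print(stable.capture_count, churning.capture_count)  # prints "1 100"
```

Under that assumption, the second case captures a graph on every call and never replays one, so each call pays the capture cost and the cache grows without bound.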
System Information
No response