make pe.inline_jaxpr_into_trace work with dynamic shapes

Open • mattjj opened this issue 7 months ago • 0 comments

To make inline_jaxpr_into_trace work with dynamic shapes, we need to perform variable substitution on the types of the jaxpr eqns being inlined. We also need to substitute the appropriate tracers into the output tracers we produce for the inlined jaxpr's outvars.
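To illustrate what "variable substitution into the types" means, here is a minimal toy sketch in plain Python. It does not use JAX's internals: `Var`, `Eqn`, `Jaxpr`, `Tracer`, and `inline_toy_jaxpr` are hypothetical stand-ins. Types are reduced to shapes whose entries are either ints or `Var`s naming a dynamic dimension. While inlining, every `Var` appearing in an eqn's output type is replaced by the outer-trace tracer it is bound to. As a result, the tracers returned for the outvars never mention the inner jaxpr's variables. The sketch assumes that each dynamic dim refers to a variable bound before the eqn that uses it.

```python
from dataclasses import dataclass


# Hypothetical toy IR loosely modeled on jaxprs with dynamic shapes.
# A shape entry is an int (static dim) or a Var (dynamic dim).
@dataclass(eq=False)  # eq=False keeps identity-based hashing for env lookups
class Var:
    name: str
    shape: tuple


@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list


@dataclass
class Jaxpr:
    invars: list
    eqns: list
    outvars: list


# Toy stand-in for a tracer in the outer trace: its shape entries are ints or
# other Tracers, never Vars from the jaxpr being inlined.
@dataclass(eq=False)
class Tracer:
    label: str
    shape: tuple


def inline_toy_jaxpr(jaxpr, in_tracers, emitted):
    """Inline `jaxpr` into a toy outer trace, recording eqns in `emitted`.

    Assumes every dynamic dim is a Var bound before the eqn that uses it
    (a jaxpr invar or an earlier eqn's outvar).
    """
    env = dict(zip(jaxpr.invars, in_tracers))

    def subst_type(shape):
        # Swap each dim Var for the outer-trace tracer it is bound to.
        return tuple(env[d] if isinstance(d, Var) else d for d in shape)

    for eqn in jaxpr.eqns:
        ins = [env[v] for v in eqn.invars]
        outs = [Tracer(f"{eqn.prim}:{v.name}", subst_type(v.shape))
                for v in eqn.outvars]
        emitted.append((eqn.prim, ins, outs))
        env.update(zip(eqn.outvars, outs))

    # Outvars resolve to tracers whose types were already substituted.
    return [env[v] for v in jaxpr.outvars]


# Usage: x has dynamic length n; z has length k, computed by an earlier eqn.
n = Var("n", ())
x = Var("x", (n,))
k = Var("k", ())
z = Var("z", (k, 3))
jaxpr = Jaxpr([n, x], [Eqn("mul2", [n], [k]), Eqn("tile", [x, k], [z])], [x, z])

m = Tracer("m", ())
a = Tracer("a", (m,))
emitted = []
out_x, out_z = inline_toy_jaxpr(jaxpr, [m, a], emitted)
print(out_z.shape[0] is emitted[0][2][0])  # True: k was replaced by its tracer
```

Keying the environment by object identity mirrors how jaxpr variables are distinguished. A real implementation would also need to handle output types that refer to other outputs of the same eqn, which this sketch leaves out.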

Behavior only changes when jax_dynamic_shapes=True, so this shouldn't affect anything other than our own tests.

mattjj • Jun 29 '24 00:06