jax
make pe.inline_jaxpr_into_trace work with dynamic shapes
To make inline_jaxpr_into_trace work with dynamic shapes, we need to perform variable substitution into the types of the jaxpr eqns being inlined. We also need to substitute the appropriate tracers into the tracers produced for the inlined jaxpr's outvars.
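As a rough illustration of the variable substitution described above, here is a minimal pure-Python sketch. It does not use JAX's internal API: the names Var, Shaped, and substitute are hypothetical stand-ins for jaxpr variables, abstract values with possibly symbolic dimensions, and the substitution step. It only shows the idea of rewriting a type whose shape mentions an inner variable so that it mentions the corresponding outer value instead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a jaxpr variable or an outer-trace tracer."""
    name: str


@dataclass(frozen=True)
class Shaped:
    """Hypothetical stand-in for a type whose dimensions may be ints or Vars."""
    shape: tuple
    dtype: str


def substitute(aval, env):
    """Return a copy of aval with each Var dimension replaced via env.

    Integer dimensions and Vars missing from env are left unchanged.
    """
    new_shape = tuple(
        env.get(d, d) if isinstance(d, Var) else d for d in aval.shape
    )
    return Shaped(new_shape, aval.dtype)


# The inlined jaxpr refers to its own variable `n` in a type ...
n = Var("n")
inner_type = Shaped((n, 3), "float32")

# ... while the enclosing trace knows the same quantity as `m`.
m = Var("m")
env = {n: m}

outer_type = substitute(inner_type, env)

# A dimension variable with no entry in env is left as-is.
k = Var("k")
untouched = substitute(Shaped((k, 2), "int32"), env)
```

With static shapes every dimension is a plain integer, so a substitution like this is a no-op, which is consistent with the change only mattering when dynamic shapes are enabled.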
The only difference is when jax_dynamic_shapes=True, so this shouldn't affect anything other than our own tests.