jax
If keep_unused is set, still apply DCE but don't prune inputs
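To illustrate the intended behavior, here is a minimal sketch on a toy equation list. It is not JAX's `_dce_helper`; the `Eqn` type and the `dce` function are hypothetical names invented for this example. The sketch assumes that dead equations are always removed, and that `keep_unused` only controls whether inputs that are no longer referenced are dropped from the signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Eqn:
    """Hypothetical toy equation: out = op(*ins). Not a JAX type."""
    out: str
    op: str
    ins: tuple


def dce(in_names, eqns, out_names, keep_unused=False):
    """Toy dead-code elimination (illustrative only).

    Dead equations are removed in both modes. Inputs that end up
    unreferenced are pruned only when keep_unused is False.
    """
    live = set(out_names)
    kept = []
    # Walk backwards so that liveness flows from outputs to inputs.
    for eqn in reversed(eqns):
        if eqn.out in live:
            kept.append(eqn)
            live.update(eqn.ins)
    kept.reverse()

    if keep_unused:
        new_inputs = list(in_names)  # signature left untouched
    else:
        new_inputs = [n for n in in_names if n in live]
    return new_inputs, kept


inputs = ["x", "y", "z"]
eqns = [
    Eqn("a", "mul", ("x", "x")),
    Eqn("b", "add", ("y", "z")),  # dead: "b" never reaches an output
    Eqn("c", "neg", ("a",)),
]
outs = ["c"]

pruned_inputs, pruned_eqns = dce(inputs, eqns, outs)
kept_inputs, kept_eqns = dce(inputs, eqns, outs, keep_unused=True)
```

In both calls the dead `add` equation is gone; the two results differ only in whether `y` and `z` remain in the input list.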
nit: One more file still uses _prune_unused_inputs: jax/tests/xla_interpreter_test.py. You will need to update it to use _dce_helper.