
XLA translation rule for primitive 'two_phase_op_lin' not found

Open · phinate opened this issue 4 years ago · 1 comment

Hi fax team!

We're still using the two_phase method to implicitly differentiate through a maximum likelihood fit. I recently tried adding a further root-finding computation after the fit, using jax.lax.custom_root to make its output differentiable.

I can evaluate the whole forward computation without issue, but calling jax.grad on the result raises this error: NotImplementedError: XLA translation rule for primitive 'two_phase_op_lin' not found

I'm not well versed in the low-level details of jax or fax, but since two_phase appears in the primitive name, I thought the error might originate on the fax side.

The code is a little involved right now, but if the error message alone isn't enough to go on, I can try to reduce it to a minimal reproduction.

Thanks again for a great tool :)

phinate · Jun 29 '20 11:06

It sounds like something is trying to take the transpose of the two-phase solver. This might happen if jax is trying to derive forward-mode differentiation from reverse-mode differentiation (or vice versa). Unfortunately, I can't say much more from the information you've given us.

The full stack trace might help me figure out what is happening, but code to reproduce the issue would be best. Otherwise, some more detail about how you are calling the two_phase method and jax.lax.custom_root would be helpful.
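For scale, a reproduction can be quite small. The sketch below is a standalone use of jax.lax.custom_root that differentiates with jax.grad. It does not involve fax or the two_phase solver, and the names sqrt_via_root, solve and tangent_solve are made up for illustration. The assumption is that swapping the toy input for the output of the maximum likelihood fit would be one way to isolate the failing step.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sqrt_via_root(a):
    """Toy example: compute sqrt(a) as the root of f(x) = x**2 - a."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def f(x):
        return x ** 2 - a

    def solve(f, x0):
        # Forward solver: a fixed number of Newton steps (hypothetical choice).
        def body(_, x):
            return x - f(x) / jax.grad(f)(x)
        return lax.fori_loop(0, 20, body, x0)

    def tangent_solve(g, y):
        # g is the linearization of f at the root. For a scalar problem
        # g(x) = g(1) * x, so solving g(x) = y is a single division.
        return y / g(jnp.ones_like(y))

    x0 = jnp.ones_like(a)  # initial guess
    return lax.custom_root(f, x0, solve, tangent_solve)


value = sqrt_via_root(4.0)               # approximately 2.0
gradient = jax.grad(sqrt_via_root)(4.0)  # approximately 0.25 = 1 / (2 * sqrt(4))
print(value, gradient)
```

Here jax.grad never differentiates through the Newton loop itself; it uses tangent_solve on the linearized problem at the root, which is why custom_root needs that extra argument.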

What is the motivation for this? Are you trying to take higher order derivatives?

gehring · Jun 29 '20 18:06