XLA translation rule for primitive 'two_phase_op_lin' not found
Hi fax team!
We're still making use of the two_phase method to implicitly differentiate through a maximum likelihood fit. In this context, I recently tried to add a further root-finding computation after the fit, using jax.lax.custom_root to try to make the output differentiable.
I can evaluate the whole forward computation with no issue, but when trying to call jax.grad on the result of this computation, I run into this error:
NotImplementedError: XLA translation rule for primitive 'two_phase_op_lin' not found
I'm not too well versed in the low-level details of jax or fax, but since two_phase appears in the primitive name, I thought it may be something on the fax side that's throwing this error.
The code is a little involved right now, but if the error message alone isn't enough to go on, I can try to reduce it to a minimal reproduction.
Thanks again for a great tool :)
It sounds like something is trying to take the transpose of the two-phase solver. This might happen if jax were trying to implement forward-mode differentiation from reverse-mode differentiation (or vice versa). Unfortunately, I can't really say much from the info you gave us.
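To illustrate the relationship between linearization and transposition mentioned here, below is a minimal sketch in plain jax (no fax involved). It linearizes a toy function, transposes the resulting linear map, and checks that this matches jax.grad. The function f and the input value are made up for illustration; this is not a reproduction of the reported error, only a picture of where a transpose step can enter.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical toy function, chosen only for illustration.
    return jnp.sin(x) * x


x = jnp.asarray(1.5)

# Forward-mode linearization: f_jvp maps an input tangent to an output tangent.
y, f_jvp = jax.linearize(f, x)

# Transposing the linear map turns it into a cotangent (reverse-mode) map.
f_vjp = jax.linear_transpose(f_jvp, x)
(grad_via_transpose,) = f_vjp(jnp.ones_like(y))

# Reference value computed directly with reverse-mode differentiation.
grad_direct = jax.grad(f)(x)
```

If a primitive appearing in the linearized computation cannot be transposed or lowered, a step like this one is a plausible place for the failure to surface.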
The full stack trace might help me figure out what is happening, but code to reproduce this issue would be best. Otherwise, some more detail about how you are calling the two phase method and jax.lax.custom_root would be helpful.
What is the motivation for this? Are you trying to take higher order derivatives?
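For reference, here is a rough sketch of what a standalone jax.lax.custom_root call with jax.grad can look like, which may be a useful starting point for a minimal reproduction. Everything in it is an assumption made for illustration: the cubic residual, the helper names newton_solve, scalar_tangent_solve and cube_root, and the fixed iteration count. It does not involve fax or the two-phase method.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0):
    # Hypothetical solver: a fixed number of Newton steps on a scalar residual f.
    def body(_, x):
        return x - f(x) / jax.grad(f)(x)

    return lax.fori_loop(0, 20, body, x0)


def scalar_tangent_solve(g, y):
    # g is a linear scalar function, g(x) = a * x, so g(x) = y gives x = y / g(1.0).
    return y / g(1.0)


def cube_root(theta):
    # Root of x**3 - theta = 0, made differentiable with respect to theta.
    f = lambda x: x**3 - theta
    x0 = jnp.ones_like(theta)
    return lax.custom_root(f, x0, newton_solve, scalar_tangent_solve)


theta = jnp.asarray(8.0)
x_star = cube_root(theta)               # root of x**3 = 8
dx_dtheta = jax.grad(cube_root)(theta)  # implicit derivative, analytically 1 / (3 * x_star**2)
```

Chaining a computation like this after the two_phase fit and calling jax.grad on the combined result would be the shape of reproduction that is easiest to debug.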