[Pallas TPU] Add lowering for `lax.erf_inv` 64 bit
Add a 64-bit lowering for `lax.erf_inv`. This is a follow-up to #22282.
I don't think we test the lowering logic with x64 enabled, and at least on GPU there are known bugs in x64 support. So I would hold off on this until we have a concrete use case.
I'm looking into the test failures on GPU.
Are we testing 64-bit on GPU at all? There is a skipTest:
https://github.com/google/jax/blob/31902775e2cc5ecb0433e520e9f12d416eda4479/tests/pallas/ops_test.py#L247-L249
We run some tests on x64, but not all, yeah.
I also added two test cases, x = ±0.999517, to cover the branch where 6.25 <= -jnp.log1p(x * -x) < 16.0.
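To check that x = ±0.999517 really lands in the middle branch, here is a minimal sketch in plain Python that evaluates the same quantity, w = -log1p(x * -x), and reports which region it falls in. The 6.25 and 16.0 thresholds come from the comment above; the assumption that the 64-bit approximation has exactly three regions (w < 6.25, 6.25 <= w < 16.0, w >= 16.0), as in Giles' double-precision erfinv approximation, is mine, and `erf_inv_branch` is a hypothetical helper, not part of JAX.

```python
import math


def erf_inv_branch(x: float) -> str:
    """Hypothetical helper: report which region of the (assumed) three-branch
    64-bit erf_inv approximation the input x falls into."""
    # Same expression as -jnp.log1p(x * -x), i.e. w = -log(1 - x**2).
    w = -math.log1p(x * -x)
    if w < 6.25:
        return "w < 6.25"
    if w < 16.0:
        return "6.25 <= w < 16.0"
    # Assumed tail region for x very close to +/-1.
    return "w >= 16.0"


if __name__ == "__main__":
    for x in (0.5, 0.999517, -0.999517):
        w = -math.log1p(x * -x)
        print(f"x = {x:+.6f}  w = {w:.4f}  branch: {erf_inv_branch(x)}")
```

For x = ±0.999517, 1 - x² is about 9.66e-4, so w is roughly 6.94, comfortably inside [6.25, 16.0). The branch is symmetric in x because w depends only on x², which is why the positive and negative cases exercise the same polynomial.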