
[Pallas TPU] Add lowering for `lax.erf_inv` 64 bit

Open ayaka14732 opened this issue 1 year ago • 4 comments

Add a lowering for 64-bit `lax.erf_inv`. This is a follow-up to #22282.

ayaka14732 avatar Jul 08 '24 11:07 ayaka14732

I don't think we test the lowering logic under x64, and at least on GPU there are known bugs with x64 support. So I would maybe delay this until we have a concrete use case.

superbobry avatar Jul 08 '24 11:07 superbobry

I am looking at the test failures on GPU

ayaka14732 avatar Sep 11 '24 00:09 ayaka14732

Are we testing 64-bit on GPU at all? There is a `skipTest`:

https://github.com/google/jax/blob/31902775e2cc5ecb0433e520e9f12d416eda4479/tests/pallas/ops_test.py#L247-L249

ayaka14732 avatar Sep 11 '24 12:09 ayaka14732

We run some tests on x64, but not all, yeah.

superbobry avatar Sep 11 '24 12:09 superbobry

I also added two test cases, x = ±0.999517, to cover the range 6.25 <= -jnp.log1p(x * -x) < 16.0.

ayaka14732 avatar Sep 19 '24 15:09 ayaka14732
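
As a quick sanity check on the last comment, the sketch below computes w = -log1p(x * -x) for x = ±0.999517 and confirms that both inputs fall in the range 6.25 <= w < 16.0. It uses `math.log1p` from the standard library as a stand-in for `jnp.log1p`. The thresholds come from the comment; the helper names `erf_inv_w` and `erf_inv_branch` and the labels "low", "middle", and "high" are hypothetical, chosen only for illustration, and are not taken from the JAX source.

```python
import math

# Thresholds quoted in the comment above.
LOW_THRESHOLD = 6.25
HIGH_THRESHOLD = 16.0


def erf_inv_w(x: float) -> float:
    """Return w = -log1p(x * -x), which equals -log(1 - x**2)."""
    return -math.log1p(x * -x)


def erf_inv_branch(x: float) -> str:
    """Classify x by the range of w it falls into (labels are hypothetical)."""
    w = erf_inv_w(x)
    if w < LOW_THRESHOLD:
        return "low"
    if w < HIGH_THRESHOLD:
        return "middle"
    return "high"


# Both test inputs from the comment should land in the middle range.
for x in (0.999517, -0.999517):
    print(x, erf_inv_w(x), erf_inv_branch(x))
```

Because w depends only on x squared, the positive and negative inputs exercise the same range, so the pair also checks that the sign of x is handled symmetrically.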