
Metal : fp64 operations with jax.numpy base functions not supported

[Open] aboucaud opened this issue 10 months ago • 3 comments

Description

Running the following:

import jax
jax.config.update("jax_enable_x64", True)
jax.numpy.linspace(0, 1, 10)

produces:

XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
...
<unknown>:0: error: failed to legalize operation 'func.func'

The same error occurs with logspace, arange, etc.
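For comparison, here is a minimal sketch (not part of the original report) that runs the same calls with `jax_enable_x64` turned on but pinned to the CPU backend, which does support float64. It assumes `jax.devices("cpu")` is still available when the jax-metal plugin is installed. `jax.default_device` only changes where uncommitted computations are placed; it does not make the Metal backend support float64.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types, as in the report above.
jax.config.update("jax_enable_x64", True)

# Assumption: the CPU backend is available alongside jax-metal.
cpu = jax.devices("cpu")[0]

# Place these computations on the CPU, where float64 is supported.
with jax.default_device(cpu):
    x = jnp.linspace(0, 1, 10)
    y = jnp.logspace(0, 1, 10)
    z = jnp.arange(0.0, 1.0, 0.1)

print(x.dtype, y.dtype, z.dtype)  # float64 float64 float64
```

This is a workaround for running float64 code on a Mac, not a fix for the Metal error itself.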

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.11.4 (main, Jun 19 2023, 22:36:35) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='altair', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')

aboucaud · Apr 25 '24 19:04

FP64 is not supported in the jax-metal backend, and it is not yet known when support will be added. We will post an update here if the situation changes.

shuhand0 · Apr 25 '24 21:04
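Since FP64 is not available on the Metal backend, a script can detect the unsupported combination up front instead of hitting the opaque `failed to legalize operation 'func.func'` error. The sketch below is hypothetical: `fp64_on_metal` is an illustrative name, and the assumption that `jax.default_backend()` reports the jax-metal platform as "metal" (compared case-insensitively) has not been confirmed in this thread.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fp64_on_metal() -> bool:
    """Hypothetical helper: True when float64 is enabled but the default backend is Metal."""
    # canonicalize_dtype maps float64 to float32 unless jax_enable_x64 is on,
    # so this reflects the effective x64 setting without allocating an array.
    x64_enabled = jax.dtypes.canonicalize_dtype(np.float64) == np.dtype(np.float64)
    # Assumption: jax-metal reports its platform name as "metal" (any case).
    on_metal = jax.default_backend().lower() == "metal"
    return bool(x64_enabled and on_metal)


jax.config.update("jax_enable_x64", True)

if fp64_on_metal():
    # Fall back to the CPU backend for float64 work instead of failing on Metal.
    print("float64 is not supported on Metal; using the CPU backend instead.")
    device = jax.devices("cpu")[0]
else:
    device = jax.devices()[0]

with jax.default_device(device):
    x = jnp.linspace(0, 1, 10)

print(x.dtype, x.devices())
```

Checking the effective dtype through `canonicalize_dtype` avoids creating a float64 array on the default device, which is exactly the operation that fails on Metal.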

Previous report: #16435

jakevdp · Apr 25 '24 21:04

Thanks for the prompt reply, and sorry for missing the previous issue.

If I may suggest, it would be helpful to mention this limitation on https://developer.apple.com/metal/jax/

aboucaud · Apr 25 '24 21:04

Thanks for the hard work!

aboucaud · Jun 15 '24 03:06