Metal: fp64 operations with jax.numpy base functions not supported
Description
Running the following:
import jax
jax.config.update("jax_enable_x64", True)
jax.numpy.linspace(0, 1, 10)
produces:
XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
...
<unknown>:0: error: failed to legalize operation 'func.func'
Also tested with logspace, arange, etc.
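A minimal sketch of one possible workaround: keep 64-bit types enabled but place the float64 computation on the CPU device instead of the Metal device. This assumes the CPU backend that ships with jaxlib is available alongside jax-metal; it is an illustration, not a fix confirmed in this thread.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types globally; on the Metal device this is what triggers the error.
jax.config.update("jax_enable_x64", True)

# Assumption: the CPU backend bundled with jaxlib is available next to jax-metal.
cpu = jax.devices("cpu")[0]

# Run the float64 computation on the CPU device rather than the default Metal device.
with jax.default_device(cpu):
    x = jnp.linspace(0, 1, 10)

print(x.dtype)  # float64
```

The rest of the program can keep using the Metal device for float32 work; only the code inside the `jax.default_device(cpu)` block is pinned to the CPU.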
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.4 (main, Jun 19 2023, 22:36:35) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='altair', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')
FP64 is not supported in the jax-metal backend, and it is currently unknown when support will be added. We will post an update here if the situation changes.
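Since FP64 is unavailable on jax-metal, a script meant to run on several backends can request 64-bit types only when the default backend is not Metal. A minimal sketch, assuming jax-metal reports its platform name as "metal" (compared case-insensitively here, to match the METAL device listed in the system info); the variable `on_metal` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumption: jax-metal's platform name is "metal" (case-insensitive),
# matching the METAL(id=0) device shown by jax.devices().
on_metal = jax.default_backend().lower() == "metal"

# Only request 64-bit types when the default backend can execute them.
jax.config.update("jax_enable_x64", not on_metal)

# float32 on Metal, float64 elsewhere.
x = jnp.linspace(0, 1, 10)
print(jax.default_backend(), x.dtype)
```

Setting the flag once, before any arrays are created, avoids mixing 32-bit and 64-bit arrays in the same program.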
Previous report: #16435
Thanks for the prompt reply, and sorry that I missed the previous issue.
If I may suggest, it would be helpful to display this information on https://developer.apple.com/metal/jax/
Thanks for the hard work!