Strange behavior of `convert_element_type` on main branch
Running from GitHub HEAD:
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
x = jnp.float64(2.718)
x_f16 = x.astype('float16')
print(x, x_f16)
# 2.718 -15790.0
Version info:
jax.print_environment_info()
# jax: 0.3.18
# jaxlib: 0.3.15
# numpy: 1.23.2
# python: 3.8.2 (v3.8.2:7b3ab5921f, Feb 24 2020, 17:52:18) [Clang 6.0 (clang-600.0.57)]
# jax.devices (1 total, 1 local): [CpuDevice(id=0)]
# process_count: 1
I can only reproduce this locally on my macbook; I've not been able to reproduce in Colab or on linux.
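To tell a correct result from a broken one without trusting any particular jaxlib build, the expected float16 value can be computed with the Python standard library alone. This is a sketch: it assumes the `struct` module's `'e'` (half-precision) format, which rounds to nearest with ties to even, as the reference, and the helper names `reference_f16` and `looks_correct` are hypothetical.

```python
import struct

def reference_f16(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary16 value.

    Packing with struct's 'e' format rounds to nearest, ties to even;
    unpacking turns the half-precision result back into a Python float.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]

def looks_correct(x: float, observed: float) -> bool:
    # A correct f64 -> f16 conversion must equal the correctly rounded value.
    return observed == reference_f16(x)

print(reference_f16(2.718))            # 2.71875
print(looks_correct(2.718, 2.71875))   # True
print(looks_correct(2.718, -15790.0))  # False
```

A working jaxlib should produce 2.71875, which NumPy and JAX display as 2.719; the -15790.0 seen here fails the check.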
FWIW, this does not reproduce on my MacBook, albeit with a (locally built) jaxlib 0.3.18, also at HEAD:
>>> import jax.numpy as jnp
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> x = jnp.float64(2.718)
>>> x_f16 = x.astype("float16")
>>> print(x, x_f16)
2.718 2.719
>>> jax.print_environment_info()
jax: 0.3.18
jaxlib: 0.3.18
numpy: 1.23.3
python: 3.10.6 (main, Aug 30 2022, 04:58:14) [Clang 13.1.6 (clang-1316.0.21.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
It does reproduce on my MacBook, with jaxlib 0.3.15, and I think that's where the problem lies.
jax.print_environment_info()
# jax: 0.3.18
# jaxlib: 0.3.15
# numpy: 1.23.3
# python: 3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]
# jax.devices (1 total, 1 local): [CpuDevice(id=0)]
# process_count: 1
This does not reproduce on my M1 MacBook Pro with either jaxlib 0.3.15 or jaxlib from HEAD.
For those of you where this fails, are you using Intel or ARM MacBooks? Can you try building jaxlib from HEAD and see if it reproduces? I suspect this is already fixed in an up-to-date jaxlib.
Mine is an Intel MacBook.
Mine is also an Intel MacBook, and the problem persists with a locally built jaxlib from HEAD.
In [1]: import jax.numpy as jnp
...: import jax
...:
...: jax.config.update('jax_enable_x64', True)
...:
...: x = jnp.float64(2.718)
...: x_f16 = x.astype('float16')
...:
...: print(x, x_f16)
...: # 2.718 -15790.0
2.718 -15790.0
In [2]: jax.print_environment_info()
jax: 0.3.18
jaxlib: 0.3.18
numpy: 1.23.3
python: 3.10.6 (main, Aug 30 2022, 05:12:36) [Clang 13.1.6 (clang-1316.0.21.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
Could you share the details of what CPU you have? Sharing the output of sysctl -a | grep machdep.cpu should do it.
@hawkinsp Here you go.
machdep.cpu.mwait.linesize_min: 64
machdep.cpu.mwait.linesize_max: 64
machdep.cpu.mwait.extensions: 3
machdep.cpu.mwait.sub_Cstates: 286531872
machdep.cpu.thermal.sensor: 1
machdep.cpu.thermal.dynamic_acceleration: 1
machdep.cpu.thermal.invariant_APIC_timer: 1
machdep.cpu.thermal.thresholds: 2
machdep.cpu.thermal.ACNT_MCNT: 1
machdep.cpu.thermal.core_power_limits: 1
machdep.cpu.thermal.fine_grain_clock_mod: 1
machdep.cpu.thermal.package_thermal_intr: 1
machdep.cpu.thermal.hardware_feedback: 0
machdep.cpu.thermal.energy_policy: 1
machdep.cpu.xsave.extended_state: 31 832 1088 0
machdep.cpu.xsave.extended_state1: 15 832 256 0
machdep.cpu.arch_perf.version: 4
machdep.cpu.arch_perf.number: 4
machdep.cpu.arch_perf.width: 48
machdep.cpu.arch_perf.events_number: 7
machdep.cpu.arch_perf.events: 0
machdep.cpu.arch_perf.fixed_number: 3
machdep.cpu.arch_perf.fixed_width: 48
machdep.cpu.cache.linesize: 64
machdep.cpu.cache.L2_associativity: 4
machdep.cpu.cache.size: 256
machdep.cpu.tlb.inst.large: 8
machdep.cpu.tlb.data.small: 64
machdep.cpu.tlb.data.small_level1: 64
machdep.cpu.address_bits.physical: 39
machdep.cpu.address_bits.virtual: 48
machdep.cpu.tsc_ccc.numerator: 200
machdep.cpu.tsc_ccc.denominator: 2
machdep.cpu.max_basic: 22
machdep.cpu.max_ext: 2147483656
machdep.cpu.vendor: GenuineIntel
machdep.cpu.brand_string: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
machdep.cpu.family: 6
machdep.cpu.model: 158
machdep.cpu.extmodel: 9
machdep.cpu.extfamily: 0
machdep.cpu.stepping: 13
machdep.cpu.feature_bits: 9221959987971750911
machdep.cpu.leaf7_feature_bits: 43804591 1073741824
machdep.cpu.leaf7_feature_bits_edx: 3154120192
machdep.cpu.extfeature_bits: 1241984796928
machdep.cpu.signature: 591597
machdep.cpu.brand: 0
machdep.cpu.features: FPU VME DE PSE TSC MSR PAE MCE CX8 APIC SEP MTRR PGE MCA CMOV PAT PSE36 CLFSH DS ACPI MMX FXSR SSE SSE2 SS HTT TM PBE SSE3 PCLMULQDQ DTES64 MON DSCPL VMX EST TM2 SSSE3 FMA CX16 TPR PDCM SSE4.1 SSE4.2 x2APIC MOVBE POPCNT AES PCID XSAVE OSXSAVE SEGLIM64 TSCTMR AVX1.0 RDRAND F16C
machdep.cpu.leaf7_features: RDWRFSGS TSC_THREAD_OFFSET SGX BMI1 AVX2 SMEP BMI2 ERMS INVPCID FPU_CSDS MPX RDSEED ADX SMAP CLFSOPT IPT SGXLC MDCLEAR IBRS STIBP L1DF ACAPMSR SSBD
machdep.cpu.extfeatures: SYSCALL XD 1GBPAGE EM64T LAHF LZCNT PREFETCHW RDTSCP TSCI
machdep.cpu.logical_per_package: 16
machdep.cpu.cores_per_package: 8
machdep.cpu.microcode_version: 240
machdep.cpu.processor_flag: 5
machdep.cpu.core_count: 8
machdep.cpu.thread_count: 16
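The fields in this dump that bear on fp16 conversion are the feature lists: F16C (hardware f32↔f16 conversion instructions) appears in `machdep.cpu.features`, and no AVX-512 flags appear in `machdep.cpu.leaf7_features`. A small parser makes that check explicit. This is a sketch: the function names are hypothetical, the sample below is abbreviated from the output above, and it assumes macOS reports any AVX-512 subset with an `AVX512` prefix.

```python
def parse_sysctl(text: str) -> dict:
    """Parse `sysctl -a | grep machdep.cpu` output into {key: value}."""
    out = {}
    for line in text.splitlines():
        key, sep, value = line.partition(": ")
        if sep:  # skip lines without a "key: value" separator
            out[key.strip()] = value.strip()
    return out

def cpu_flags(info: dict) -> set:
    # Feature names are space-separated in these three keys.
    keys = ("machdep.cpu.features", "machdep.cpu.leaf7_features",
            "machdep.cpu.extfeatures")
    return {flag for k in keys for flag in info.get(k, "").split()}

# Abbreviated from the dump above.
sample = """machdep.cpu.brand_string: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
machdep.cpu.features: FPU SSE SSE2 AVX1.0 RDRAND F16C
machdep.cpu.leaf7_features: BMI1 AVX2 SMEP BMI2"""

flags = cpu_flags(parse_sysctl(sample))
print("F16C" in flags)                             # True
print(any(f.startswith("AVX512") for f in flags))  # False
```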
Could someone who experiences this problem please run the reproduction with the environment variable XLA_FLAGS=--xla_dump_to=/tmp/somewhere and share a zip file or similar with the directory of files it produces?
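One way to do this from Python is to run the reproduction in a child process whose environment carries the flag. XLA parses `XLA_FLAGS` when the backend first initializes, so setting it in a fresh process avoids import-order surprises. This is a sketch: `env_with_dump` is a hypothetical helper, and `/tmp/somewhere` is the placeholder path from the request above.

```python
import os
import subprocess
import sys

# The reproduction from the top of this issue, run in a child process.
REPRO = """
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
x = jnp.float64(2.718)
print(x, x.astype('float16'))
"""

def env_with_dump(dump_dir, base=None):
    """Return a copy of the environment with --xla_dump_to added to XLA_FLAGS."""
    env = dict(os.environ if base is None else base)
    # XLA_FLAGS holds space-separated flags; keep any that are already set.
    env["XLA_FLAGS"] = f"{env.get('XLA_FLAGS', '')} --xla_dump_to={dump_dir}".strip()
    return env

if __name__ == "__main__":
    result = subprocess.run([sys.executable, "-c", REPRO],
                            env=env_with_dump("/tmp/somewhere"),
                            capture_output=True, text=True)
    print(result.stdout or result.stderr)
```

The dump directory then contains the HLO and generated-code files to zip up.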
@hawkinsp Here you go again. dump.zip
My current best guess is that this may be related to the _Float16 ABI changing in LLVM.
This program in essence compiles into a single call to a builtin named __truncdfhf2. The only explanation I can think of is that there is a mismatch of calling conventions for Mac x86-64.
Notably this changed in LLVM recently: https://reviews.llvm.org/D131172
LLVM changed the x86-64 fp16 ABI from passing the value in integer registers to passing it in floating-point registers. This means that the __truncdfhf2 provided by the system is now incompatible with the LLVM in jaxlib.
If I remember correctly, __truncdfhf2 is provided by the system on macOS, not by Xcode. That would mean we're stuck on the wrong ABI until Apple decides to change it, which seems unlikely to happen outside of a major release. I'm curious what will happen when they hit this change while updating Clang in Xcode, though.
This isn't an issue for f32->f16 because there has been a hardware instruction for it (F16C) since Ivy Bridge, but there's none for f64->f16 (one exists in AVX-512 FP16, but no Mac was released with it).
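A natural follow-up is why the compiler can't chain a hardware f64->f32 conversion with the F16C f32->f16 instruction. Doing so rounds twice, and double rounding can give a different answer from a single correctly rounded f64->f16 conversion, which is one reason a dedicated routine like __truncdfhf2 is used. A minimal demonstration, using Python's `struct` formats `'f'` and `'e'` as stand-ins for the two rounding steps (the helper names are hypothetical):

```python
import struct

def to_f32(x: float) -> float:
    # Round binary64 -> binary32 (nearest, ties to even) and back.
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_f16(x: float) -> float:
    # Round binary64 -> binary16 directly (nearest, ties to even) and back.
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Just above the midpoint between the f16 neighbours 1.0 and 1 + 2**-10.
x = 1.0 + 2**-11 + 2**-40

direct = to_f16(x)            # one rounding: goes up to 1 + 2**-10
two_step = to_f16(to_f32(x))  # f32 drops 2**-40, leaving an exact tie -> 1.0
print(direct, two_step)       # 1.0009765625 1.0
```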
I think we can work around this by making simple_orc_jit bind __truncdfhf2 to the fallback version in runtime_fp16.cc, which I fixed to use the correct ABI. I have no Intel Mac around to test that.
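For illustration, here is what a software f64->f16 fallback of this kind has to compute. This is a hypothetical Python sketch, not the code in runtime_fp16.cc: `truncdfhf2_sketch` takes the double as a Python float and returns the 16-bit pattern, rounding to nearest with ties to even and overflowing to infinity.

```python
import struct

def truncdfhf2_sketch(x: float) -> int:
    """Convert a binary64 value to a binary16 bit pattern (hypothetical)."""
    bits = struct.unpack("<Q", struct.pack("<d", x))[0]
    sign = (bits >> 48) & 0x8000        # f64 sign bit moved to f16 bit 15
    exp = (bits >> 52) & 0x7FF          # biased f64 exponent
    frac = bits & ((1 << 52) - 1)       # f64 fraction field

    if exp == 0x7FF:                    # Inf stays Inf, NaN becomes a quiet NaN
        return sign | 0x7C00 | (0x200 if frac else 0)
    if exp == 0 and frac == 0:          # signed zero
        return sign

    # Exact decomposition: |x| == sig * 2**e.
    if exp == 0:                        # f64 subnormal: no implicit leading 1
        sig, e = frac, -1074
    else:
        sig, e = frac | (1 << 52), exp - 1075

    top = sig.bit_length() - 1 + e      # floor(log2(|x|))
    # Exponent of the f16 spacing at this magnitude. Below 2**-14 the
    # spacing is fixed at 2**-24 (the f16 subnormal range).
    q = max(top, -14) - 10

    # Round |x| / 2**q to an integer. shift is always positive because an
    # f64 significand carries more bits than f16 keeps.
    shift = q - e
    m, rem = sig >> shift, sig & ((1 << shift) - 1)
    half = 1 << (shift - 1)
    if rem > half or (rem == half and m & 1):
        m += 1                          # round to nearest, ties to even
    if m == 2048:                       # rounding carried into the next binade
        m, q = 1024, q + 1

    if m < 1024:                        # f16 subnormal (or rounded to zero)
        return sign | m
    if q + 10 > 15:                     # above the f16 range: infinity
        return sign | 0x7C00
    return sign | ((q + 25) << 10) | (m - 1024)

h = truncdfhf2_sketch(2.718)
print(hex(h), struct.unpack("<e", struct.pack("<H", h))[0])  # 0x4170 2.71875
```

Because the result is an integer bit pattern computed from the input's bits, nothing in this sketch depends on which registers a half-precision value travels in, which is exactly the part that broke.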