Add float8_e4m3 and float8_e3m4 types support
Description
Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These types are implemented in commercially available hardware (Amazon EC2 Trn1 instances) and have been added to the MLIR builtin types, LLVM APFloat, ml_dtypes, and StableHLO.
XLA has a Float8E4M3 and Float8E3M4 implementation in review; see the PR links in the Related PRs section below.
This PR adds f8E4M3 and f8E3M4 types support to JAX.
f8E4M3 type follows IEEE 754 convention.
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
including the implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs
Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
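For concreteness, the f8E4M3 layout above can be checked with a small hand-rolled decoder. This is a hypothetical helper, not part of JAX or ml_dtypes, written only to verify the table:

```python
def decode_f8e4m3(byte: int) -> float:
    """Decode an 8-bit f8E4M3 pattern: 1 sign, 4 exponent, 3 mantissa bits, bias 7."""
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 3) & 0xF   # 4 stored exponent bits
    mant = byte & 0x7         # 3 mantissa bits
    if exp == 0xF:            # S.1111.mmm: inf (mant == 0) or NaN, per IEEE 754
        return sign * float("inf") if mant == 0 else float("nan")
    if exp == 0:              # subnormal: no implicit leading 1, exponent -6
        return sign * (mant / 8) * 2.0 ** (-6)
    return sign * (1 + mant / 8) * 2.0 ** (exp - 7)

assert decode_f8e4m3(0b0_1110_111) == 240.0      # max normal
assert decode_f8e4m3(0b0_0001_000) == 2.0 ** -6  # min normal
assert decode_f8e4m3(0b0_0000_001) == 2.0 ** -9  # min subnormal
```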
f8E3M4 type follows IEEE 754 convention.
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs
Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
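The same kind of check works for f8E3M4 (again a hypothetical helper, not library code):

```python
def decode_f8e3m4(byte: int) -> float:
    """Decode an 8-bit f8E3M4 pattern: 1 sign, 3 exponent, 4 mantissa bits, bias 3."""
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 4) & 0x7   # 3 stored exponent bits
    mant = byte & 0xF         # 4 mantissa bits
    if exp == 0x7:            # S.111.mmmm: inf (mant == 0) or NaN, per IEEE 754
        return sign * float("inf") if mant == 0 else float("nan")
    if exp == 0:              # subnormal: no implicit leading 1, exponent -2
        return sign * (mant / 16) * 2.0 ** (-2)
    return sign * (1 + mant / 16) * 2.0 ** (exp - 3)

assert decode_f8e3m4(0b0_110_1111) == 15.5       # max normal
assert decode_f8e3m4(0b0_001_0000) == 0.25       # min normal
assert decode_f8e3m4(0b0_000_0001) == 2.0 ** -6  # min subnormal
```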
Related PRs:
- LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes PR-161 Add float8_e4m3 (Merged)
- ml_dtypes PR-171 Add float8_e3m4 (Merged)
- XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Merged)
- XLA PR-16585 Add support for float8_e4m3 and float8_e3m4 types (in Review)
How to build/install
This PR requires ml_dtypes version 20240821 or later. The current version on PyPI is 0.4.0, released on April 1, 2024, which predates this requirement, so ml_dtypes should be installed from source.
Related issue: https://github.com/jax-ml/ml_dtypes/issues/185 [Question] Can we release a new version of ml_dtypes?
```shell
## Install the latest ml_dtypes
cd ml_dtypes
pip3 install .

## Install jaxlib and JAX
cd jax
### install jaxlib
python3 build/build.py
pip3 install dist/*.whl
### install jax
pip3 install .
```
Smoke test

```python
import jax
import jax.numpy as jnp
from jax import Array, random
from jax._src.lib.mlir.dialects import hlo

jax.devices()
hlo.get_version_from_compatibility_requirement(
    hlo.StablehloCompatibilityRequirement.WEEK_4
)
hlo.get_version_from_compatibility_requirement(
    hlo.StablehloCompatibilityRequirement.WEEK_12
)

dtype = "float8_e4m3"
# dtype = "float8_e3m4"
key1 = random.PRNGKey(41)
key2 = random.PRNGKey(42)
a = random.uniform(key1, shape=(16, 16), dtype=dtype)
b = random.uniform(key2, shape=(16, 16), dtype=dtype)

def foo(a, b):
    return a @ b

foo_jit = jax.jit(foo)

# StableHLO
print(foo_jit.lower(a, b).as_text())
# HLO
print(foo_jit.lower(a, b).compile().as_text())

c = foo(a, b).block_until_ready()
c2 = foo_jit(a, b).block_until_ready()
i = 0
while i < 10000:
    c2 = foo_jit(a, b).block_until_ready()
    i += 1
```

Example output:

```
Array([[3.25, 2.75, 2.5, 2.5, 2.25, 2.25, 2.5, 3.25, 2.5, 3.25, 2, 2.25,
        2.5, 3, 2.75, 2.75],
       ...
       [4, 3.5, 3.5, 2.25, 2, 2.25, 3, 3.25, 2.25, 3, 2.75, 3, 2.5, 3.25,
        2, 2.75]], dtype=float8_e4m3)
```
Thanks for the contribution! I don't think we'll be able to bump our ml_dtypes requirement any time soon, so if we want to merge this we'll have to make it robust to older ml_dtypes versions (the reason is that tensorflow pins a specific ml_dtypes version, and some workflows depend on installing both JAX and tensorflow).
The good news is this is easy enough to do with a few version guards: if you look at the initial implementation of float8 types in JAX, you can see the pattern we used previously.
Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71
Basically, we only define the dtype in JAX if it's defined in ml_dtypes.
Another strategy we could use is the module-level __getattr__ for these types, so that if the ml_dtypes version is too old, we raise an error that specifies what version is required.
Incidentally, the current TF pin is: Requires-Dist: ml-dtypes<0.5.0,>=0.3.1.
If we release ml_dtypes as 0.4.1 instead of 0.5.0 we probably could bump the minimum version.
I suspect we could ease this process if we committed to semver for ml_dtypes so TF felt like they could be less conservative in their pins. (Adding dtypes is hopefully safe!)
That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason).
I updated the PR and tested it with ml_dtypes 0.4.0 and 0.5.0 @jakevdp @hawkinsp
Please fix the lint issues – thanks! Also, the test failures look real. It seems that there's some place where the new float8 types must be registered
MyPy
fixed mypy issues
btw, Contributing to JAX explains how to run lint/ruff/mypy/jupytext locally
```shell
pre-commit run --all-files
```
All passed now
Regarding failed tests
FAILED tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e3m4
e4m3 and e3m4 were added to stablehlo 1.7.0 (3 weeks ago, Sep 4, 2024)
jax/_src/export/_export.py uses

```python
target_version = hlo.get_version_from_compatibility_requirement(
    hlo.StablehloCompatibilityRequirement.WEEK_4)
```

which returns target_version 1.5.0.
Workaround: use NONE instead of WEEK_4, which returns 1.7.5.
FAILED tests/array_test.py::JaxArrayTest::test_shards_have_correct_dtype17
FAILED tests/dtypes_test.py::TestPromotionTables::testFloat8PromotionError
The tests work fine if I use XLA COMMIT_ID from XLA PR https://github.com/openxla/xla/pull/16585 Add support for float8_e4m3 and float8_e3m4 types (in Review)
I guess we need to put this PR on hold and rerun the tests once XLA PR-16585 is merged to XLA main and XLA_COMMIT is updated in JAX. StableHLO WEEK_4 issue should resolve itself in 1-2 weeks too.
Test report on CPU:

```shell
pytest -n auto tests/
```

26921 passed, 12530 skipped in 492.32s (0:08:12)
Update
The XLA PR #16585, which adds support for float8_e4m3 and float8_e3m4 types, has just been merged into the XLA main branch.
XLA COMMIT_ID=693ee2e13225331bebc946442af7e2d59355adea
I've noticed several "Update XLA dependency to use revision" commits in JAX's recent commit history. I expect that JAX will automatically update to include the above commit ID sometime this week.
Rebased to include the latest XLA with support for f8e4m3 and f8e3m4.
Regarding the StableHLO VHLO issue: get_version_from_compatibility_requirement(WEEK_4) still returns 1.5.0, even though StableHLO 1.6.0 was released 49 days ago and 1.7.0 was released 28 days ago, according to VhloDialect.td.
As a result, test_poly_numeric_dtypes_dtype is still failing: Failed to serialize StableHLO.
FAILED tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e3m4 - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO;
FAILED tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e4m3 - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO;
Jake, could you advise on this issue? @jakevdp
I'm not sure how the stablehlo version is pinned. Maybe @hawkinsp knows?
The issue was fixed in this stablehlo PR - https://github.com/openxla/stablehlo/pull/2579
```cpp
case CompatibilityRequirement::WEEK_4:
  // Before the fix: return Version(1, 5, 0);  // v1.5.0 - Aug 1, 2024
  return Version(1, 7, 0);                     // v1.7.0 - Sept 05, 2024
```
These versions are defined by jaxlib, and https://github.com/openxla/stablehlo/pull/2579 isn't included in the v0.4.34 release from earlier today. That being said, I think it would be fine to land this change (with the caveat that I haven't reviewed closely) after adding these types here:
https://github.com/jax-ml/jax/blob/83b0a932bdc542df82c26a57c4b1777809621258/tests/export_test.py#L908-L914
to skip the offending tests for now.
StableHLO was fixed to return 1.7.0 for WEEK_4 req.
- XLA has been updated to use the latest StableHLO
- JAX has been updated to use the latest XLA
export_test::test_poly_numeric_dtypes works fine for f8e4m3 and f8e3m4 now.
All pytests passed now (on CPU)!
== 26919 passed, 12575 skipped in 516.53s (0:08:36) ==
Jake, can you help with rerunning github/test workflows?
I checked failed CI / build with 3.12 (py 3.12 on ubuntu-20.04-16core, x64=0) Logs
The test environment uses mix of JAX versions - jaxlib-0.4.34 and jax-0.4.35.dev.
Should I recreate this situation locally (jaxlib-0.4.34 and jax-0.4.35.dev) and add workarounds/skips to all the failed tests?
Yesterday I ran pytest locally with both jax and jaxlib at version 0.4.35.dev, and all tests passed. @dfm @jakevdp
The log below shows that the py 3.12 test env uses a mix of JAX versions: jaxlib-0.4.34 and jax-0.4.35.dev.
```
2024-10-07T13:05:52.3536429Z Collecting jaxlib<=0.4.35,>=0.4.34 (from jax==0.4.35.dev20241006+702ec6c)
2024-10-07T13:05:52.3548468Z Using cached jaxlib-0.4.34-cp312-cp312-manylinux2014_x86_64.whl.metadata (983 bytes)
2024-10-07T13:05:53.2087941Z Using cached jaxlib-0.4.34-cp312-cp312-manylinux2014_x86_64.whl (86.2 MB)
2024-10-07T13:06:07.4570054Z Successfully installed jax-0.4.35.dev20241006+702ec6c jaxlib-0.4.34
2024-10-07T13:06:10.3094007Z Requirement already satisfied: jaxlib<=0.4.35,>=0.4.34 in /opt/hostedtoolcache/Python/3.12.6/x64/lib/python3.12/site-packages (from jax==0.4.35.dev20241006+702ec6c) (0.4.34)
```
This is intended: jax at HEAD should always run correctly with the latest jaxlib release (and in general should be compatible with minimum_jaxlib_version). This lets you iterate on the Python parts of JAX without having to build jaxlib locally.
The general pattern we use to avoid version skew is to check xla_extension_version within the jax Python code, and if necessary increment it within jaxlib.
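A minimal sketch of that version-skew guard pattern; the helper name `supports_new_float8` and the threshold value 290 are hypothetical, for illustration only (real code would compare against the `xla_extension_version` exported by the installed jaxlib):

```python
def supports_new_float8(xla_extension_version: int) -> bool:
    # Gate features on the installed jaxlib's extension version, so jax at
    # HEAD still runs correctly against the latest jaxlib release.
    # 290 is an illustrative threshold, not the actual value for these dtypes.
    return xla_extension_version >= 290

# Older jaxlib release: feature unavailable, related tests should be skipped.
assert not supports_new_float8(289)
# New enough jaxlib: feature available.
assert supports_new_float8(290)
```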
Updated test_util.py, dtypes_test.py and export_test.py to skip tests for f8e4m3 and f8e3m4 dtypes when jaxlib is below version 0.4.35.
Tested jax-0.4.35.dev with both jaxlib versions - 0.4.34 and 0.4.35.dev
```
>>> jax.__version__
'0.4.35.dev20241007+be74ecf1b'
>>> jaxlib.__version__
'0.4.34'
```

```shell
pytest -n auto tests/
```

=== 26881 passed, 12589 skipped in 564.83s (0:09:24) ===
All checks have passed. Is there anything else needed before it can be merged?
All checks passed except for:
- docs/readthedocs.org:jax (Pending Error): "Your project, organization, or user has reached its maximum number of concurrent builds allowed (2). This build will automatically retry in 5 minutes."
Not sure if it is actually trying to retry.
We're seeing some internal failures on GPU and TPU backends. I'll try to debug.
Will try to build/test on GPU instance
The error is this:
```
APITest.test_jit_custom_floats_float8_e4m3:
...
"/build/.../jax/_src/array.py", line 624, in _value
    self._npy_value = self._single_device_array_to_np_array()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
xla.python.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Unsupported type in PrimitiveTypeToDataType 28
```
Something to do with one of the lower-level passes not knowing how to consume the new serialized types
I tested the PR on GPU instance (nvidia A10G 23GB) and found that dtypes_test.py, export_test.py and api_test.py failed for f8e3m4 and f8e4m3 types. Error: Failed to serialize StableHLO.
The issue occurs because PjRtCApiCompiler::Compile() calls xla::GetDefaultStablehloVersion(), which uses WEEK_12 and returns StableHLO version 1.1.0. However, these new types require version 1.7.0.
When I replace WEEK_12 with WEEK_4, the tests pass successfully on the GPU.
Temporarily, I limited the f8e3m4 and f8e4m3 tests to run on "cpu" only:

```python
# TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+
if device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35):
    ...
```
I also opened an XLA PR to review the possibility to use WEEK_4 in xla::GetDefaultStablehloVersion() - https://github.com/openxla/xla/pull/18117
> I tested the PR on GPU instance (nvidia A10G 23GB) and found that dtypes_test.py, export_test.py and api_test.py failed for f8e3m4 and f8e4m3 types. Error: Failed to serialize StableHLO. The issue occurs because PjRtCApiCompiler::Compile() calls xla::GetDefaultStablehloVersion(), which uses WEEK_12 and returns StableHLO version 1.1.0. However, these new types require version 1.7.0.
Yeah, we're aware of this. We need to plumb the plugin's stablehlo version to JAX, which we haven't done yet. We should always be able to produce the newest stablehlo the consumer can consume, but currently we are overly conservative, I think. @dfm
- The message "Unsupported type in PrimitiveTypeToDataType" from the error report above does not appear to exist in XLA, ml_dtypes, JAX, or StableHLO. I found it only in TensorFlow, specifically in the file tensorflow/compiler/tf2xla/type_util.cc.
It is likely that the same code is utilized in some closed-source PJRT plugins/compilers.
device_under_test() == "cpu" should fix "Unsupported type in PrimitiveTypeToDataType" case in APITest.test_jit_custom_floats_float8_e4m3 too.
The tests work fine on GPU if WEEK_4 is used in xla::GetDefaultStablehloVersion()
=== jtu.device_under_test(): gpu
tests/api_test.py::APITest::test_jit_custom_floats_float8_e3m4 PASSED
tests/api_test.py::APITest::test_jit_custom_floats_float8_e4m3 PASSED
tests/dtypes_test.py::TestPromotionTables::testJaxTypeFromType_jaxtype=dtype(float8_e3m4) PASSED
tests/dtypes_test.py::TestPromotionTables::testJaxTypeFromType_jaxtype=dtype(float8_e4m3) PASSED
tests/dtypes_test.py::TestPromotionTables::testJaxTypeFromVal_jaxtype=dtype(float8_e3m4) PASSED
tests/dtypes_test.py::TestPromotionTables::testJaxTypeFromVal_jaxtype=dtype(float8_e4m3) PASSED
tests/dtypes_test.py::TestPromotionTables::testJaxTypeWeak_dtype=dtype(float8_e3m4) PASSED
tests/dtypes_test.py::TestPromotionTables::testJaxTypeWeak_dtype=dtype(float8_e4m3) PASSED
tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e3m4 PASSED
tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e4m3 PASSED
Jake, can you help with re-running the tests? I temporarily limited the f8e3m4 and f8e4m3 tests to run on "cpu" only. @jakevdp The tests work fine on GPU if WEEK_4 is used in xla::GetDefaultStablehloVersion().
Summary of PR Testing on GPU:
Installed packages:
- jax-0.4.35.dev
- jaxlib-0.4.35.dev
- ml_dtypes-0.5.0
get_version_from_compatibility_requirement(WEEK_12) => StableHLO-1.1.0
Results:
- The smoke test script (provided in the PR description) that uses the dot operation successfully passed on the GPU, including tests using jax.jit
- All pytest tests passed on the GPU instance.*
*Note: Initial test failures occurred for f8e3m4 and f8e4m3 on the GPU due to the use of StableHLO version 1.1.0 (WEEK_12). Temporary workaround: These tests are currently restricted to run on "CPU" only. I plan to re-enable GPU testing for these types after November 28, 2024, when WEEK_12 will be updated to use StableHLO version 1.7.0.
Jake, Peter, Dan, Could you advise on the next steps for this PR?
@jakevdp @hawkinsp @dfm
We're still seeing some new failures in the PJRT runtime – I'm not sure how to address those. @hawkinsp do you have thoughts on how to proceed here?
StableHLO WEEK_12 is in the process of updating to version 1.5.0 (https://github.com/openxla/stablehlo/pull/2599)
WEEK_12 is scheduled to switch to version 1.7.0 (where f8e4m3 was added) in about 36 days (on or after November 28, 2024).
Are there any potential workarounds for the "failed internal tests for GPU and TPU backends" (reported here) that could allow this PR to be shipped earlier?
@hawkinsp @jakevdp @dfm
I retested this PR using jaxlib 0.4.35 (released Oct 22) on a server with 4 GPU devices.
The checks and workarounds I previously added to skip certain tests for f8e4m3/f8e3m4 types are no longer needed, so I removed them and rebased this PR.
All tests passed on the 4-GPU setup.
Jake, Peter, do you think we should attempt merging this PR again now that jaxlib 0.4.35 was released two weeks ago and all tests passed on GPU without any workarounds or skips? @jakevdp @hawkinsp
I think https://github.com/jax-ml/jax/pull/24956 will unblock merging this.
The problem is that we shouldn't be running tests for every custom dtype when a given backend supports only a subset of them.