Jake Vanderplas
Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71 Basically, we only define the dtype in JAX if it's defined in `ml_dtypes`. Another strategy we could use...
Please fix the lint issues – thanks! Also, the test failures look real. It seems the new float8 types still need to be registered somewhere.
I'm not sure how the stablehlo version is pinned. Maybe @hawkinsp knows?
This is intended: `jax` at HEAD should always run correctly with the latest `jaxlib` release (and in general should be compatible with `minimum_jaxlib_version`). This lets you iterate on the Python...
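To make that compatibility rule concrete, here's a minimal sketch of the kind of check it implies: `jax` accepts any `jaxlib` at or above `minimum_jaxlib_version`. The function names are hypothetical and the parsing is simplified (it drops suffixes such as `.dev…`); this isn't JAX's actual implementation.

```python
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Keep the leading numeric components, e.g. "0.4.23.dev20240101" -> (0, 4, 23)."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break  # stop at the first non-numeric component
        parts.append(int(piece))
    return tuple(parts)


def jaxlib_is_compatible(jaxlib_version: str, minimum_jaxlib_version: str) -> bool:
    """Hypothetical check: any jaxlib at or above the minimum is acceptable."""
    return parse_version(jaxlib_version) >= parse_version(minimum_jaxlib_version)
```

Comparing integer tuples rather than raw strings matters here: as strings, `"0.4.10" < "0.4.9"`, but as tuples `(0, 4, 10) > (0, 4, 9)`.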
We're seeing some internal failures on GPU and TPU backends. I'll try to debug.
The error is this:

```
APITest.test_jit_custom_floats_float8_e4m3:
...
  "/build/.../jax/_src/array.py", line 624, in _value
    self._npy_value = self._single_device_array_to_np_array()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
xla.python.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Unsupported type in PrimitiveTypeToDataType 28
```

Something to do with one of...
We're still seeing some new failures in the PJRT runtime – I'm not sure how to address those. @hawkinsp do you have thoughts on how to proceed here?
v0.4.23 is fairly old; I'd suggest trying with a more recent JAX version, particularly if you are using more recent CUDA versions. > I tried this specific version because it...
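When comparing versions, it helps to confirm what's actually installed. Here's a small standard-library sketch (the helper name `installed_versions` is hypothetical) that reports the installed `jax` and `jaxlib` distributions, or `None` if one is missing:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("jax", "jaxlib"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(installed_versions())
```

Recent JAX releases also provide `jax.print_environment_info()`, which prints these versions along with other details that are useful in a bug report.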
Thanks for the report! So the tricky thing here is that `jax.scipy.sparse.linalg.cg` implements the API of `scipy.sparse.linalg.cg`, which has no explicit `symmetric` flag. We could probably add an optional `symmetric`...
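For context on why symmetry matters: conjugate gradient assumes `A` is symmetric (Hermitian) positive definite, and, as in SciPy, nothing in the call signature checks this. Below is a minimal pure-Python sketch of the textbook algorithm (not JAX's implementation), plus a hypothetical `is_symmetric` helper a caller could run before solving.

```python
from typing import List, Optional

Matrix = List[List[float]]
Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def matvec(A: Matrix, x: Vector) -> Vector:
    return [dot(row, x) for row in A]


def is_symmetric(A: Matrix, tol: float = 1e-12) -> bool:
    """Check A[i][j] == A[j][i] (within tol) for every off-diagonal pair."""
    n = len(A)
    return all(abs(A[i][j] - A[j][i]) <= tol
               for i in range(n) for j in range(i + 1, n))


def conjugate_gradient(A: Matrix, b: Vector, tol: float = 1e-10,
                       maxiter: Optional[int] = None) -> Vector:
    """Textbook CG; only valid when A is symmetric positive definite."""
    n = len(b)
    maxiter = 10 * n if maxiter is None else maxiter
    x = [0.0] * n
    r = list(b)  # residual b - A @ x, with x starting at zero
    p = list(r)  # initial search direction
    rs_old = dot(r, r)
    for _ in range(maxiter):
        if rs_old ** 0.5 <= tol:
            break  # residual norm is small enough
        Ap = matvec(A, p)
        alpha = rs_old / dot(p, Ap)  # step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        beta = rs_new / rs_old  # keeps the next direction A-conjugate
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x = conjugate_gradient(A, b)  # approximately [1/11, 7/11]
```

In JAX itself the corresponding call is `jax.scipy.sparse.linalg.cg(A, b)`, which returns a tuple whose first element is the solution; for a non-symmetric `A` it may return an inaccurate result rather than raise an error.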
Is this a change you're interested in contributing?