xarray_jax does not support jax.jit().lower
The JAX API now offers finer-grained control over the compilation process through jax.stages, but the xarray_jax wrapper in graphcast does not seem to support jax.jit().lower:
import graphcast.xarray_jax as xarray_jax
import jax.numpy as jnp
import jax
def ident(a):  # Trivial test function
    """Return ``a`` unchanged.

    Minimal function used to exercise ``jax.jit`` and ``jax.jit(...).lower``
    with different argument types.
    """
    return a
# Sample inputs: a plain JAX array and the same array wrapped in an xarray_jax DataArray
foo = jnp.ones(3)
foo_xr = xarray_jax.DataArray(foo)
# Calling the jitted function directly works for both kinds of input.
print(jax.jit(ident)(foo)) # Works
# [1. 1. 1.]
print(jax.jit(ident)(foo_xr)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0
# Ahead-of-time lowering works for the plain array but not for the wrapped one.
jax.jit(ident).lower(foo) # Works
# <jax._src.stages.Lowered at 0x151bb5e04830>
jax.jit(ident).lower(foo_xr) # Fails with the ValueError shown in the traceback below
Traceback
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[13], line 1
----> 1 jax.jit(ident).lower(foo_xr) # Fails
[... skipping hidden 5 frame]
File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:668, in _unflatten_variable(aux, children)
666 dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
667 if dims_change_fn: dims = dims_change_fn(dims)
--> 668 return Variable(dims=dims, data=children[0])
File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:113, in Variable(dims, data, **kwargs)
111 def Variable(dims, data, **kwargs) -> xarray.Variable: # pylint:disable=invalid-name
112 """Like xarray.Variable, but can wrap JAX arrays."""
--> 113 return xarray.Variable(dims, wrap(data), **kwargs)
File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/core/variable.py:365, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
338 def __init__(
339 self,
340 dims,
(...)
344 fastpath=False,
345 ):
346 """
347 Parameters
348 ----------
(...)
363 unrecognized encoding items.
364 """
--> 365 super().__init__(
366 dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs
367 )
369 self._encoding = None
370 if encoding is not None:
File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:253, in NamedArray.__init__(self, dims, data, attrs)
246 def __init__(
247 self,
248 dims: _DimsLike,
249 data: duckarray[Any, _DType_co],
250 attrs: _AttrsLike = None,
251 ):
252 self._data = data
--> 253 self._dims = self._parse_dimensions(dims)
254 self._attrs = dict(attrs) if attrs else None
File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:481, in NamedArray._parse_dimensions(self, dims)
479 dims = (dims,) if isinstance(dims, str) else tuple(dims)
480 if len(dims) != self.ndim:
--> 481 raise ValueError(
482 f"dimensions {dims} must have the same length as the "
483 f"number of data dimensions, ndim={self.ndim}"
484 )
485 if len(set(dims)) < len(dims):
486 repeated_dims = set([d for d in dims if dims.count(d) > 1])
ValueError: dimensions ('dim_0',) must have the same length as the number of data dimensions, ndim=0
If the xarray DataArray is created inside the jitted function, things seem to work:
def make_xr(a):
    """Wrap ``a`` in an ``xarray_jax.DataArray``."""
    return xarray_jax.DataArray(a)
def compose(a):
    """Build the DataArray from ``a`` inside the function, then apply ``ident``.

    Because the DataArray is constructed here rather than passed in, only the
    raw array ``a`` crosses the ``jax.jit`` boundary as an argument.
    """
    return ident(make_xr(a))
# Lower, compile, and run with the plain array as the only jit argument.
print(jax.jit(compose).lower(foo).compile()(foo)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0
I'm not yet sure whether exploding xarray arguments into a more pytree-friendly form, only to recreate them inside a wrapper, is a general solution, or whether doing so with graphcast would just reveal an error further down the line.
Self-plug here, but there is now a separate xarray_jax package on PyPI; I just tested it, and it works :)
https://github.com/allen-adastra/xarray_jax
def ident(a):  # Trivial test function
    """Return ``a`` unchanged.

    Minimal function used to exercise ``jax.jit`` and ``jax.jit(...).lower``
    with different argument types.
    """
    return a
# Sample variables
# NOTE(review): `xr` is not imported in this snippet; presumably `import xarray as xr`
# alongside the standalone xarray_jax package — confirm.
foo = jnp.ones(3)
foo_xr = xr.DataArray(foo)
print(jax.jit(ident)(foo)) # Works
# [1. 1. 1.]
print(jax.jit(ident)(foo_xr)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0
jax.jit(ident).lower(foo) # Works
# <jax._src.stages.Lowered at 0x151bb5e04830>
# NOTE(review): the "# Fails" marker and the outputs above look copied from the
# original report; the accompanying post says this works with the standalone
# package — confirm.
jax.jit(ident).lower(foo_xr) # Fails
That's potentially useful, how is your package's compatibility with graphcast?
I don't know. It does handle things differently.
Hi all, this is now fixed in graphcast's own xarray_jax. The fix required being able to wrap jax.stages.ArgInfo. See https://github.com/google-deepmind/graphcast/blob/main/graphcast/xarray_jax_test.py#L144