xarray_jax does not support jax.jit().lower

Open csubich opened this issue 1 year ago • 3 comments

The JAX API now includes more detailed control over the compilation process with jax.stages, but the xarray_jax wrapper here in graphcast does not seem to support jax.jit().lower:

import graphcast.xarray_jax as xarray_jax
import jax.numpy as jnp
import jax

def ident(a): # Trivial test function
    return a

# Sample variables
foo = jnp.ones(3)
foo_xr = xarray_jax.DataArray(foo)

print(jax.jit(ident)(foo)) # Works
# [1. 1. 1.]

print(jax.jit(ident)(foo_xr)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0

jax.jit(ident).lower(foo) # Works
# <jax._src.stages.Lowered at 0x151bb5e04830>

jax.jit(ident).lower(foo_xr) # Fails
Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 1
----> 1 jax.jit(ident).lower(foo_xr) # Fails

    [... skipping hidden 5 frame]

File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:668, in _unflatten_variable(aux, children)
    666 dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
    667 if dims_change_fn: dims = dims_change_fn(dims)
--> 668 return Variable(dims=dims, data=children[0])

File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:113, in Variable(dims, data, **kwargs)
    111 def Variable(dims, data, **kwargs) -> xarray.Variable:  # pylint:disable=invalid-name
    112   """Like xarray.Variable, but can wrap JAX arrays."""
--> 113   return xarray.Variable(dims, wrap(data), **kwargs)

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/core/variable.py:365, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    338 def __init__(
    339     self,
    340     dims,
   (...)
    344     fastpath=False,
    345 ):
    346     """
    347     Parameters
    348     ----------
   (...)
    363         unrecognized encoding items.
    364     """
--> 365     super().__init__(
    366         dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs
    367     )
    369     self._encoding = None
    370     if encoding is not None:

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:253, in NamedArray.__init__(self, dims, data, attrs)
    246 def __init__(
    247     self,
    248     dims: _DimsLike,
    249     data: duckarray[Any, _DType_co],
    250     attrs: _AttrsLike = None,
    251 ):
    252     self._data = data
--> 253     self._dims = self._parse_dimensions(dims)
    254     self._attrs = dict(attrs) if attrs else None

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:481, in NamedArray._parse_dimensions(self, dims)
    479 dims = (dims,) if isinstance(dims, str) else tuple(dims)
    480 if len(dims) != self.ndim:
--> 481     raise ValueError(
    482         f"dimensions {dims} must have the same length as the "
    483         f"number of data dimensions, ndim={self.ndim}"
    484     )
    485 if len(set(dims)) < len(dims):
    486     repeated_dims = set([d for d in dims if dims.count(d) > 1])

ValueError: dimensions ('dim_0',) must have the same length as the number of data dimensions, ndim=0
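
One plausible reading of the traceback is that, during `lower`, JAX rebuilds the argument pytree around placeholder objects rather than real arrays, and the unflatten step then fails xarray's dimension check because a non-array placeholder looks like a 0-d value. The sketch below reproduces that mechanism with `jax.tree_util` alone. `StrictLabeled` is a hypothetical stand-in for an xarray-like container, not graphcast code, and the bare `object()` leaf is an assumed substitute for whatever placeholder JAX actually passes.

```python
import jax
import jax.numpy as jnp
import numpy as np


class StrictLabeled:
    """Hypothetical xarray-like container that validates dims on construction."""

    def __init__(self, dims, data):
        ndim = np.ndim(data)  # evaluates to 0 for objects that are not array-like
        if len(dims) != ndim:
            raise ValueError(
                f"dimensions {dims} must have the same length as the "
                f"number of data dimensions, ndim={ndim}"
            )
        self.dims = dims
        self.data = data


# Register the container as a pytree: data is the only child, dims is static aux data.
jax.tree_util.register_pytree_node(
    StrictLabeled,
    lambda x: ((x.data,), x.dims),
    lambda dims, children: StrictLabeled(dims, children[0]),
)

x = StrictLabeled(("dim_0",), jnp.ones(3))
leaves, treedef = jax.tree_util.tree_flatten(x)

# Round-tripping with the real array leaves passes validation.
roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)

# Rebuilding the same tree around a non-array placeholder trips the check.
try:
    jax.tree_util.tree_unflatten(treedef, [object()])
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```

The message has the same shape as the one in the traceback, which suggests the problem sits in the unflatten path rather than in the jitted function itself.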

If the xarray is created inside a JITted function, things seem to work:

def make_xr(a):
    return xarray_jax.DataArray(a)

def compose(a):
    return (ident(make_xr(a)))

print(jax.jit(compose).lower(foo).compile()(foo)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0

I'm not yet sure whether exploding xarray arguments into a more pytree-friendly form, only to recreate them inside a wrapper, is a general solution, or whether doing so with graphcast would just reveal another error further in.
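
A minimal sketch of that "explode and recreate" idea, written against a stand-in container so it runs with JAX alone. `Labeled`, `explode`, `rebuild`, and `lower_exploded` are all hypothetical names introduced for illustration. The sketch also assumes the function preserves the dimension names, so the input dims can be reused for the output; real graphcast functions may not satisfy that.

```python
import jax
import jax.numpy as jnp


class Labeled:
    """Hypothetical stand-in for an xarray-like container (not a registered pytree)."""

    def __init__(self, dims, data):
        self.dims = dims
        self.data = data


def explode(x):
    """Split a container into its array payload and its static metadata."""
    return x.data, x.dims


def rebuild(data, dims):
    """Recreate the container from an array payload and static metadata."""
    return Labeled(dims, data)


def lower_exploded(fn, example):
    """Lower fn so that only plain arrays cross the jit boundary."""
    data, dims = explode(example)

    def inner(raw):
        out = fn(rebuild(raw, dims))  # the container is recreated inside the trace
        return explode(out)[0]        # only the array leaves the jitted function

    return jax.jit(inner).lower(data), dims


def double(x):  # Trivial test function that keeps the dims unchanged
    return Labeled(x.dims, x.data * 2)


x = Labeled(("dim_0",), jnp.ones(3))
lowered, dims = lower_exploded(double, x)
out = rebuild(lowered.compile()(x.data), dims)
print(out.dims, out.data)
```

The cost of this design is that every call site has to unpack and repack its arguments by hand, and any metadata the function changes has to be carried out of the jitted region separately.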

csubich, Sep 23 '24 16:09

Self-plug here, but there is now a separate xarray_jax package on PyPI, and I just tested this example with it and it works :) https://github.com/allen-adastra/xarray_jax

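    import jax
    import jax.numpy as jnp
    import xarray as xr
    import xarray_jax  # Assumed: the PyPI package; this import is not shown in the original comment

    def ident(a):  # Trivial test function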
        return a

    # Sample variables
    foo = jnp.ones(3)
    foo_xr = xr.DataArray(foo)

    print(jax.jit(ident)(foo))  # Works
    # [1. 1. 1.]

    print(jax.jit(ident)(foo_xr))  # Works
    # <xarray.DataArray (dim_0: 3)>
    # xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
    # Dimensions without coordinates: dim_0

    jax.jit(ident).lower(foo)  # Works
    # <jax._src.stages.Lowered at 0x151bb5e04830>

    jax.jit(ident).lower(foo_xr)  # Fails with graphcast's xarray_jax; reported above to work with this package

allen-adastra, Oct 16 '24 14:10

That's potentially useful. How compatible is your package with graphcast?

csubich, Oct 16 '24 14:10

I don't know. It does handle things differently.

allen-adastra, Oct 16 '24 14:10

Hi all, this is now fixed in graphcast's own xarray_jax. The fix requires xarray_jax to be able to wrap jax.stages.ArgInfo objects. See https://github.com/google-deepmind/graphcast/blob/main/graphcast/xarray_jax_test.py#L144
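
For readers who hit the same problem with their own pytree types, here is a minimal sketch of the general idea: let the unflatten function accept leaves that are not arrays, and validate only when a real array or tracer is present. This is not graphcast's actual implementation. `TolerantLabeled` and its `_unchecked` constructor are hypothetical names, and the `isinstance` test is an assumed way of telling arrays from placeholders.

```python
import jax
import jax.numpy as jnp
import numpy as np


class TolerantLabeled:
    """Hypothetical xarray-like container whose pytree unflatten tolerates placeholders."""

    def __init__(self, dims, data):
        if len(dims) != np.ndim(data):
            raise ValueError(f"dimensions {dims} do not match ndim={np.ndim(data)}")
        self.dims = dims
        self.data = data

    @classmethod
    def _unchecked(cls, dims, data):
        # Bypass validation for leaves that are not arrays.
        obj = object.__new__(cls)
        obj.dims = dims
        obj.data = data
        return obj


def _flatten(x):
    return (x.data,), x.dims


def _unflatten(dims, children):
    data = children[0]
    if isinstance(data, (jax.Array, np.ndarray)):  # real arrays and tracers
        return TolerantLabeled(dims, data)
    return TolerantLabeled._unchecked(dims, data)  # placeholders pass through unvalidated


jax.tree_util.register_pytree_node(TolerantLabeled, _flatten, _unflatten)


def ident(a):  # Trivial test function
    return a


x = TolerantLabeled(("dim_0",), jnp.ones(3))
lowered = jax.jit(ident).lower(x)
result = lowered.compile()(x)
print(result.dims, result.data)
```

The trade-off is that a container built around a placeholder is only safe to pass around, not to compute with, so validation still has to happen wherever real arrays enter.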

mjwillson, Jul 02 '25 09:07