Better support for symbolic (e.g. jax tracer) values in non-index coordinates
Is your feature request related to a problem?
Hello! xarray_jax maintainer here.
I am trying to decide whether it is feasible to rely on using symbolic arrays (such as jax tracers used during JIT compilation) as coordinate data in xarray.
For index coordinates this doesn't work at all, because xarray requires dynamic access to the data in the index to do alignment. So our approach in xarray_jax is to treat index coordinates as static data. I think this is acceptable: even if we could do index alignment with tracers, the result would be variable-shape which jax.jit can't handle without re-tracing anyway.
For non-index coordinates, though, it is possible to use jax tracers! However, these coordinates have a habit of disappearing due to failed alignment checks, for example in arithmetic operations:
```python
import xarray as xr
import jax
import jax.numpy as jnp

def foo(coord):
    a = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
    b = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
    print((a + b).coords)

foo(jnp.arange(3))
# prints:
# Coordinates:
#     foo      (x) int32 12B ...

jax.jit(foo)(jnp.arange(3))
# prints:
# Coordinates:
#     *empty*
```
The coordinate is (correctly) preserved when not JIT-ing, because xarray is able to do an equality comparison to check that the coordinate is compatible.
When JIT-ing, however, the coordinate array will be a jax Tracer, and the equality comparison will fail with a jax.errors.TracerBoolConversionError (or similar). Because TracerBoolConversionError subclasses TypeError, xarray catches it under the hood and treats the failed comparison as not-equal, so the coordinate is dropped.
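A toy, stdlib-only model of this mechanism may help. FakeTracer, FakeTracerBoolConversionError, equals and merge_non_index_coords are hypothetical stand-ins, not xarray or jax code. The only behaviours carried over from the description above are that the tracer's error subclasses TypeError and that a caught TypeError is treated as "not equal".

```python
# Toy model (hypothetical names, not xarray/jax code) of how a failed equality
# check on a symbolic value turns into a silently dropped coordinate.


class FakeTracerBoolConversionError(TypeError):
    """Stand-in for jax.errors.TracerBoolConversionError, a TypeError subclass."""


class FakeTracer:
    """Stand-in for a jax Tracer: '==' works, but the result can't become a bool."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        # Like a tracer, comparison returns another symbolic value.
        return FakeTracer(a == b for a, b in zip(self.values, other.values))

    def __bool__(self):
        raise FakeTracerBoolConversionError("cannot convert a tracer to bool")


def equals(a, b):
    """Catch-and-return-False pattern: any TypeError counts as 'not equal'."""
    try:
        return bool(a == b)
    except (TypeError, AttributeError):
        return False


def merge_non_index_coords(left, right):
    """Keep coords from either side, dropping those that clash (or seem to)."""
    merged = {}
    for name in left.keys() | right.keys():
        if name not in right:
            merged[name] = left[name]
        elif name not in left:
            merged[name] = right[name]
        elif equals(left[name], right[name]):
            merged[name] = left[name]
        # otherwise: treated as conflicting and silently dropped
    return merged


concrete = [0, 1, 2]
print(merge_non_index_coords({"foo": concrete}, {"foo": concrete}))  # {'foo': [0, 1, 2]}

traced = FakeTracer([0, 1, 2])
print(merge_non_index_coords({"foo": traced}, {"foo": traced}))  # {}
```

The same coordinate survives with concrete values and vanishes with the symbolic stand-in, without any error being surfaced.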
This isn't entirely consistent with behaviour elsewhere in xarray.
For example when doing an explicit merge it is possible to set compat='override' which ensures that coordinates aren't tested for equality and so jax tracers survive. And I believe compat='override' is to become the default soon.
xr.align also lets jax Tracers survive:
```python
def foo(coord):
    a = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
    b = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
    a, b = xr.align(a, b)
    print(a.coords)
    print(b.coords)

jax.jit(foo)(jnp.arange(3))
# prints:
# Coordinates:
#     foo      (x) int32 12B ...
# Coordinates:
#     foo      (x) int32 12B ...
```
Overall, while I can see why some might want them, these compatibility checks on non-index coordinates seem a somewhat awkward niche behaviour: they are configurable in some places but not others, and (IIUC, at least once the compat='override' default takes over) they will have different defaults in different places too, which is not ideal. And of course they throw a bit of a spanner in the works of using jax here.
Describe the solution you'd like
Three suggestions:
1. Add an xarray.set_options(arithmetic_compat='override') setting which controls this, and make 'override' its default at the same time as compat='override' becomes the default more widely. Note there is already an arithmetic_join setting which controls how index coordinates are joined/aligned, but it doesn't give any control over these compatibility checks on non-index coordinates. One might also consider exposing the compat option on xr.align too, and in any other places where compatibility checks are happening (or, for consistency, ought to happen), to give control over this.
2. Allow jax.errors.TracerBoolConversionError to bubble up from generic equality-testing code (or perhaps catch it and re-raise it as something less jax-specific, like xr.SymbolicComparisonError) rather than catching it and treating it as unequal. Then, in coordinate compatibility-checking code, catch this error and fall back on compat='override' behavior in this specific case. This has the advantage that non-index-coord compatibility checks don't have to be disabled across the board, only for jax Tracers. It would require some jax-specific special-casing in xarray, though.
3. Just get rid of alignment checks on non-index coordinates entirely, if no one feels strongly about them. They are already somewhat inconsistent, both in configurability and in default behaviour. In many cases it's a safe assumption that any non-index coordinates are a fixed function of the index coordinates, so it's sufficient to check the alignment of the index coordinates. In other cases it would perhaps be OK to make any checking the user's responsibility. I realise this would be backwards-incompatible, though.
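As a rough sketch of how suggestions 1 and 2 could fit together, here is a toy model in which merge_coords stands in for coordinate merging during arithmetic. SymbolicComparisonError and SymbolicValue are hypothetical (the xr.SymbolicComparisonError name above is itself only a proposal). The sketch makes SymbolicComparisonError subclass TypeError, as jax's tracer errors do, to show why the re-raise clause has to come before the generic TypeError handler.

```python
# Toy model (hypothetical names, not xarray code) of suggestions 1 and 2.


class SymbolicComparisonError(TypeError):
    """Hypothetical library-agnostic error: the values are symbolic, so
    equality can't be decided. Subclasses TypeError like jax's tracer errors."""


class SymbolicValue:
    """Stand-in for a tracer: '==' is allowed, but bool() of the result fails."""

    def __eq__(self, other):
        return SymbolicValue()

    def __bool__(self):
        raise SymbolicComparisonError("value is symbolic")


def equals(a, b):
    try:
        return bool(a == b)
    except SymbolicComparisonError:
        raise  # suggestion 2: let it bubble up; must precede the TypeError clause
    except (TypeError, AttributeError):
        return False


def merge_coords(left, right, compat="minimal"):
    merged = {**right, **left}  # left wins where both define a coordinate
    if compat == "override":
        return merged  # suggestion 1: skip the comparisons entirely
    for name in left.keys() & right.keys():
        try:
            same = equals(left[name], right[name])
        except SymbolicComparisonError:
            same = True  # suggestion 2: override behaviour for this coordinate only
        if not same:
            del merged[name]  # compat="minimal": drop conflicting coordinates
    return merged


print(merge_coords({"c": 0}, {"c": 1}))                     # {}
print(merge_coords({"c": 0}, {"c": 1}, compat="override"))  # {'c': 0}
s = SymbolicValue()
print(list(merge_coords({"s": s}, {"s": s})))               # ['s']
```

With compat="override" no comparison happens at all; with the default, only the coordinate whose comparison is undecidable falls back to override, while genuinely clashing concrete coordinates are still dropped.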
Describe alternatives you've considered
See above.
Additional context
No response
I haven't been involved much with this code for a while, but I would love to support xarray_jax and think it's worth some modest tradeoffs in order to do that.
I think (3) is interesting, but I would suggest we avoid coupling: it would be great to let xarray_jax iterate without waiting for broad agreement about backward-incompatible changes.
Would (1) be the least invasive?
👍 for doing (1) now, and then considering (3). In the long term, it would be nice if nothing in Xarray required computing of non-index array values.
Thanks both. Yes, I think (1) would be less invasive than (3). (2) is hopefully not too invasive either, but it would require a small amount of jax-specific code, something like:
```python
SYMBOLIC_COMPARISON_ERRORS = ()
try:
    import jax
    SYMBOLIC_COMPARISON_ERRORS += (jax.errors.ConcretizationTypeError,)
except ImportError:
    pass
```
I think that even if we implement (1), it would be better to let errors like TracerBoolConversionError bubble up rather than silently treating them as not-equal in compatibility checks, which can give different results from the same code depending on whether it's JIT-ed or not. Perhaps the exception could be re-raised, wrapped in an xarray error which tells you that you can avoid the comparison by enabling arithmetic_compat='override' or compat='override'? Admittedly this is less of an issue if arithmetic_compat='override' is the default, though.
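A sketch of that re-raise idea, reusing the guarded jax import pattern above. ComparisonNotPossibleError and the hint text are hypothetical, and arithmetic_compat is the option proposed in this thread, not something the sketch assumes already exists. The symbolic_errors parameter is only there so the behaviour can be exercised without jax installed.

```python
# Sketch (hypothetical names): wrap "value is symbolic" errors in an error
# that explains how to avoid the comparison, instead of returning False.
SYMBOLIC_COMPARISON_ERRORS = ()
try:
    import jax

    SYMBOLIC_COMPARISON_ERRORS += (jax.errors.ConcretizationTypeError,)
except ImportError:
    pass  # jax not installed: nothing extra to treat specially


class ComparisonNotPossibleError(ValueError):
    """Hypothetical xarray-side wrapper for library-specific symbolic errors."""


HINT = (
    "Equality comparison failed because a value is symbolic (e.g. a jax Tracer "
    "under jax.jit). To skip the comparison, use compat='override' or the "
    "proposed arithmetic_compat='override' option."
)


def equals(a, b, symbolic_errors=SYMBOLIC_COMPARISON_ERRORS):
    try:
        return bool(a == b)
    except symbolic_errors as e:  # checked first: these may subclass TypeError
        raise ComparisonNotPossibleError(HINT) from e
    except (TypeError, AttributeError):
        return False


# Demo without jax: a stand-in error type and a value whose bool() raises it.
class DemoSymbolicError(TypeError):
    pass


class DemoTracer:
    def __eq__(self, other):
        return self

    def __bool__(self):
        raise DemoSymbolicError("symbolic")


try:
    equals(DemoTracer(), DemoTracer(), symbolic_errors=(DemoSymbolicError,))
except ComparisonNotPossibleError as e:
    print(type(e.__cause__).__name__)  # DemoSymbolicError
```

Chaining with `from e` keeps the original library error visible in the traceback, while the wrapper's message points at the configuration that avoids the comparison.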
Anyway I can have a go at (1) if I get time, although if anyone familiar with the many codepaths relating to compat checks in xarray fancies it (or can give some pointers) that would also be great :)
I would search for _binary_op and specify compat in align. See https://github.com/pydata/xarray/blob/eb01d9c60ffacb0c433e7d439973f869b8fd81b3/xarray/core/dataarray.py#L4891-L4893 for an example
The _binary_op method of DataTree, Dataset and DataArray is the starting point, but it's this line that does the comparison of coordinates without indexes: https://github.com/pydata/xarray/blob/eb01d9c60ffacb0c433e7d439973f869b8fd81b3/xarray/core/dataarray.py#L4902
We'd need to pass arithmetic_compat to _merge_raw (which needs a new compat parameter) and then make sure it is passed to merge_coordinates_without_align (which also doesn't have the compat parameter yet) and merge_collected.
Thanks for the pointers @keewis !
@dcherian
I would search for _binary_op and specify compat in align
There is actually no compat argument to align. In fact, align doesn't align non-index coordinates at all, for example:
```python
In [11]: a = xr.DataArray([1,2,3], dims=['x'], coords={'foo': (['x'], [1,2,3])})
In [12]: b = xr.DataArray([1,2,3], dims=['x'], coords={'foo': (['x'], [4,5,6])})
In [13]: xr.align(a, b)
Out[13]:
(<xarray.DataArray (x: 3)> Size: 24B
 array([1, 2, 3])
 Coordinates:
     foo      (x) int64 24B 1 2 3
 Dimensions without coordinates: x,
 <xarray.DataArray (x: 3)> Size: 24B
 array([1, 2, 3])
 Coordinates:
     foo      (x) int64 24B 4 5 6
 Dimensions without coordinates: x)
```
This isn't a problem for xarray_jax (in fact it makes its life easier -- the less code that relies on evaluating non-index coordinates the easier it'll be to use jax arrays for them) but it is one of the inconsistencies I mentioned above.
it is one of the inconsistencies I mentioned above
I wouldn't see this as an inconsistency (if I understand correctly what you're referring to): the purpose of align is just to make sure that the indexes of all arguments are the same. For that, it only compares indexes, and the non-index coordinates are treated like data variables.
Binary operations need to align first and then compute the actual operation, but when assembling the result they also need to merge the two sets of coordinates, and that's where we compare the non-index coordinates.
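To illustrate the distinction, here is a toy model in which align only compares the index and passes non-index coordinates through, while a binary operation additionally merges coordinates and drops those that conflict. The dict-based objects and function names are hypothetical, and this align only accepts identical indexes (roughly like join='exact') rather than reindexing.

```python
# Toy model (hypothetical, not xarray code): align compares only the index;
# a binary op aligns, computes, then merges non-index coordinates.


def align(left, right):
    """Only the index is compared; non-index coordinates are left untouched."""
    if left["index"] != right["index"]:
        raise ValueError("indexes differ")  # a real align could reindex instead
    return left, right


def add(left, right):
    left, right = align(left, right)
    data = [a + b for a, b in zip(left["data"], right["data"])]
    lc, rc = left["coords"], right["coords"]
    coords = {}
    for name in lc.keys() | rc.keys():
        if name not in rc:
            coords[name] = lc[name]
        elif name not in lc:
            coords[name] = rc[name]
        elif lc[name] == rc[name]:
            coords[name] = lc[name]
        # else: conflicting non-index coordinate is dropped
    return {"index": left["index"], "coords": coords, "data": data}


a = {"index": [0, 1, 2], "coords": {"foo": [1, 2, 3]}, "data": [1, 2, 3]}
b = {"index": [0, 1, 2], "coords": {"foo": [4, 5, 6]}, "data": [1, 2, 3]}

a2, b2 = align(a, b)
print(a2["coords"], b2["coords"])  # {'foo': [1, 2, 3]} {'foo': [4, 5, 6]}
print(add(a, b)["coords"])         # {}
```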
OK, yeah that's fair enough, I guess I wasn't quite clear on the intention behind align.
OK, so I got an arithmetic_compat option working in the above PR.
One issue with moving away from arithmetic_compat='minimal' (the setting which silently drops clashing coordinates) is that it's quite common to get left with clashing scalar coordinates in cases like arr[0] + arr[1], and in that case you'd usually want the coordinate dropped, as mentioned here:
https://docs.xarray.dev/en/stable/user-guide/computation.html#coordinates
One can always explicitly drop the coordinates in cases like this, though (e.g. using .isel with drop=True), and I'd be a bit loath to special-case scalar coordinates. Any preference on what to do here?
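A toy model of the arr[0] + arr[1] case shows the tradeoff: 'minimal' drops the clashing scalar coordinate, 'override' keeps the left operand's value (which would be misleading here), and dropping the coordinate at selection time, as .isel(..., drop=True) does, avoids the clash whatever the compat mode. select and add are hypothetical stand-ins, not xarray functions.

```python
# Toy model (hypothetical, not xarray code) of clashing scalar coordinates.


def select(arr, i, drop=False):
    """Positional selection: the selected index value becomes a scalar
    coordinate, as with arr[i], unless drop=True."""
    coords = {} if drop else {"x": arr["x"][i]}
    return {"data": arr["data"][i], "coords": coords}


def add(a, b, compat="minimal"):
    coords = {**b["coords"], **a["coords"]}  # left operand wins
    if compat == "minimal":
        for name in a["coords"].keys() & b["coords"].keys():
            if a["coords"][name] != b["coords"][name]:
                del coords[name]  # clashing scalar coordinate is dropped
    elif compat != "override":
        raise ValueError(f"unsupported compat: {compat!r}")
    return {"data": a["data"] + b["data"], "coords": coords}


arr = {"x": [10, 20], "data": [1.0, 2.0]}
print(add(select(arr, 0), select(arr, 1))["coords"])                     # {}
print(add(select(arr, 0), select(arr, 1), compat="override")["coords"])  # {'x': 10}
dropped = add(select(arr, 0, drop=True), select(arr, 1, drop=True), compat="override")
print(dropped["coords"])                                                 # {}
```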
Indeed, now that I recall this behavior, I'm pretty sure that's why the default compat mode for arithmetic was minimal originally!
This is trickier than I thought. For now, I would just add the new configuration option.
OK I've updated #10943 to just do this.
Any thoughts on whether we should let materialization errors like TracerBoolConversionError bubble up from DataArray.equals etc, rather than treating it as not-equal which can lead to different behaviour in JIT'd vs non-JIT'd contexts?
I'm not 100% sure why we catch all exceptions here. I suspect they exist at least to handle types where equality can fail with ValueError/TypeError. You could try relaxing to only catch those and seeing if that turns up any test failures.
For reference, these are handled within the equals, broadcast_equals and identical methods on Variable:
https://github.com/pydata/xarray/blob/6e82a3afa8e47e7ed59441e77f812ceeaeaf5668/xarray/core/variable.py#L1865-L1881
This treats TypeError and AttributeError as "not equal", and because TracerBoolConversionError inherits from TypeError (IIUC), we're catching that as well. I'm not sure how we'd get this to bubble up besides special-casing jax or restricting to exactly TypeError, neither of which seems great.
I think something along these lines wouldn't be the end of the world:
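```python
SYMBOLIC_COMPARISON_ERRORS = ()
try:
    import jax
    SYMBOLIC_COMPARISON_ERRORS += (jax.errors.ConcretizationTypeError,)
except ImportError:
    pass

...

def equals(self, other):  # simplified signature; e.g. within Variable.equals
    try:
        ...
    # must come first: ConcretizationTypeError is itself a TypeError subclass
    except SYMBOLIC_COMPARISON_ERRORS as e:
        raise ValueError("Equality comparison of two xarray objects failed due to ...") from e
    except (TypeError, AttributeError):
        return False
```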
I say that because the jax-specific bit is quite contained, it handles the missing dependency, and there's nothing stopping folks from adding similar lines for other libraries with similar errors in future.
Or, better yet, make the except (TypeError, AttributeError): clause less broad if possible, but without knowing what it's there to address, it's hard to say what this might break or what it should be replaced with.
Ideally, I feel the only errors swallowed here should be ones known to be raised in a consistent, fixed set of scenarios, not ones that might be raised in some situations but not others, leading to downstream behaviours like coordinates silently disappearing in one situation but not another.
At any rate, maybe we could start with the arithmetic_compat thing for now. If anyone's able to take a look at the PR (#10943), that'd be brilliant :)