map_coordinates mode='constant' not properly applied
jax.scipy.ndimage.map_coordinates behaves differently from scipy.ndimage.map_coordinates when using mode="constant". In the example below, the coordinate 2.9 should be just as "off the array" as 3, because it lies beyond the last index (2). Yet JAX returns a different value for it (1), which is wrong. It should return what the SciPy implementation returns (0, i.e. cval).
from scipy import ndimage
import jax
import jax.numpy as jnp
src = jnp.array([5, 1, 6]) # Fails!
# src = jnp.array([5, 1, 5]) # Doesn't fail
# 2.9 is first floored (jnp.floor) to 2, which is a valid index, so JAX interpolates between src[2] and cval.
coords = [[2, 3, 2.9]]
def scipy_map_coordinates():
    return ndimage.map_coordinates(src, coords, order=1, mode="constant")

def jax_map_coordinates():
    return jax.scipy.ndimage.map_coordinates(
        src, coords, order=1, mode="constant")
print(scipy_map_coordinates())
print(jax_map_coordinates())
assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())
Yielding:
[6 0 0]
[6 0 1]
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-6-b13f0479754d> in <module>()
20 print(jax_map_coordinates())
21
---> 22 assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())
AssertionError:
CC: @claudiofantacci
[5, 1, 5] doesn't fail because 5 * weight is internally rounded down to zero. That matches the SciPy output only "by luck". The internal tests also pad the input in an unusual way, turning [5, 1, 6] into [5, 1, 6, 0]. This doesn't seem consistent with the expected behaviour.
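The "by luck" explanation can be checked with plain arithmetic. The following is a pure-Python sketch. It *assumes*, and does not quote from JAX's source, three things. First, JAX computes the weight of the upper neighbour as x - floor(x) and the lower weight as 1 minus that, all in float32. Second, out-of-range neighbours read as cval. Third, the result is rounded back to an integer for integer inputs. The code uses `struct` only to emulate float32 rounding. The names `f32`, `round_half_away_from_zero` and `linear_grid_constant_f32` are hypothetical helpers.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float to the nearest float32 (JAX's default float dtype)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def round_half_away_from_zero(x: float) -> int:
    """Round to the nearest integer; ties go away from zero."""
    whole = math.floor(abs(x))
    magnitude = whole + (1 if abs(x) - whole >= 0.5 else 0)
    return int(math.copysign(magnitude, x))


def linear_grid_constant_f32(src, x, cval=0):
    """Assumed model of order=1 interpolation at x: neighbours outside src read
    as cval, every intermediate is rounded to float32, and the float result is
    rounded back to an integer (src is an integer array)."""
    x = f32(x)
    lower = math.floor(x)
    upper_weight = f32(x - lower)           # weight of src[lower + 1]
    lower_weight = f32(1.0 - upper_weight)  # weight of src[lower]

    def sample(i):
        return src[i] if 0 <= i < len(src) else cval

    value = f32(f32(lower_weight * sample(lower)) + f32(upper_weight * sample(lower + 1)))
    return value, round_half_away_from_zero(value)


for src in ([5, 1, 6], [5, 1, 5]):
    value, rounded = linear_grid_constant_f32(src, 2.9)
    print(src, value, rounded)
```

Under these assumptions, the float32 weight of src[2] at 2.9 comes out just below 0.1. So 6 * weight is just below 0.6 and rounds to 1, while 5 * weight is just below 0.5 and rounds to 0. Neither value is an exact tie, so the tie-breaking rule does not affect the outcome.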
I think this relates to some of the discrepancies in boundary conditions that we noticed in the behavior of SciPy's map_coordinates: https://github.com/scipy/scipy/issues/2640
The SciPy issue was closed by https://github.com/scipy/scipy/pull/12776. It sounds like the decision was to add a new mode, grid-constant, rather than attempt to fix constant. I'm not sure that's the choice I would have made, but I guess it's probably the right decision for us, too.
Hi @shoyer, thanks for your answer!
I apologize: my description may not have been as clear as it sounded in my mind.
From my experiments, it seems that the JAX version of map_coordinates with mode="constant" is equivalent to mode="grid-constant" in SciPy 1.6. In other words, the JAX version interpolates even beyond the edges of the input.
See https://docs.scipy.org/doc/scipy/reference/tutorial/ndimage.html#ndimage-interpolation-modes for details on what I mean, in particular for order=1 and mode="grid-constant":
import jax
import jax.numpy as jnp

src = jnp.array([5, 1, 6])
coords = [[3]]
# Coordinate 3 is outside the array, so this returns cval = 0, as expected.
print(jax.scipy.ndimage.map_coordinates(src, coords, order=1, mode="constant"))
# Coordinate 2.9 is also outside the array, yet this returns a non-zero value:
# wrong for "constant", but correct for "grid-constant".
coords = [[2.9]]
print(jax.scipy.ndimage.map_coordinates(src, coords, order=1, mode="constant"))
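The difference between the two modes is easier to see in a model. Below is a simplified, pure-Python model of order=1 interpolation on a 1-D float array. It follows the mode descriptions in the SciPy tutorial linked above, not SciPy's or JAX's actual source, and the helper names are hypothetical. In "grid-constant", the array is treated as padded with cval and interpolation continues across the edge. In "constant", any coordinate beyond the edges returns cval directly. A float input is used so that no integer rounding is involved.

```python
import math


def _linear(src, x, cval):
    """order=1 interpolation at x; neighbours outside src read as cval."""
    lower = math.floor(x)
    upper_weight = x - lower

    def sample(i):
        return src[i] if 0 <= i < len(src) else cval

    return (1.0 - upper_weight) * sample(lower) + upper_weight * sample(lower + 1)


def interp_grid_constant(src, x, cval=0.0):
    """'grid-constant' as described by SciPy: interpolate everywhere, including
    between the last sample and the cval padding beyond it."""
    return _linear(src, x, cval)


def interp_constant(src, x, cval=0.0):
    """'constant' as described by SciPy >= 1.6: no interpolation beyond the
    edges, so any x outside [0, len(src) - 1] returns cval."""
    if x < 0 or x > len(src) - 1:
        return cval
    return _linear(src, x, cval)


src = [5.0, 1.0, 6.0]
for x in (2.0, 2.9, 3.0):
    print(x, interp_constant(src, x), interp_grid_constant(src, x))
```

In this model, 2.9 gives 0.0 under "constant" and roughly 0.6 under "grid-constant". That presumably becomes the 0 vs. 1 seen in the report once the result is rounded back to the integer dtype of src.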
So, to be consistent with SciPy and also agree with the documentation, I think we should both:
- rename constant to grid-constant
- (optional) implement constant as SciPy does.
What do you think? Is this what you meant as well? For what it's worth, in this particular mode and situation, I don't think that SciPy's output is wrong.
> So, to be consistent with SciPy and also agree with the documentation, I think we should both:
> - rename constant to grid-constant
> - (optional) implement constant as SciPy does.
> What do you think? Is this what you meant as well?
Yes, exactly!
> For what it's worth, in this particular mode and situation, I don't think that SciPy's output is wrong.
I believe SciPy has updated its documentation for what mode='constant' means, alongside the introduction of mode='grid-constant'.
Hi @agudallago,
> rename constant to grid-constant
The documentation of jax.scipy.ndimage.map_coordinates now mentions that mode='constant' in JAX behaves like mode='grid-constant' in SciPy. Please check the documentation here: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html
Thank you.
We should update the implementation to match scipy now that they've fixed the issue on their end.
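One possible user-side workaround, pending an implementation change, is to mask out-of-range samples after the call. The sketch assumes only the public `jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode, cval)` signature. `map_coordinates_no_edge_interp` is a hypothetical helper, not part of JAX. It sets every sample whose coordinate falls outside [0, n - 1] on any axis to cval. For order=1, this is intended to match SciPy's documented "constant" behaviour while leaving in-range samples untouched.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def map_coordinates_no_edge_interp(arr, coordinates, order, cval=0.0):
    """Hypothetical helper: call JAX's map_coordinates with mode="constant",
    then force every sample whose coordinate lies outside [0, n - 1] on any
    axis to cval, so no interpolation happens beyond the edges."""
    arr = jnp.asarray(arr)
    coords = [jnp.asarray(c) for c in coordinates]
    result = map_coordinates(arr, coords, order=order, mode="constant", cval=cval)
    in_bounds = jnp.ones(result.shape, dtype=bool)
    for axis, c in enumerate(coords):
        in_bounds = in_bounds & (c >= 0) & (c <= arr.shape[axis] - 1)
    return jnp.where(in_bounds, result, jnp.asarray(cval, dtype=result.dtype))


src = jnp.array([5, 1, 6])
print(map_coordinates_no_edge_interp(src, [[2, 3, 2.9]], order=1))
```

The mask depends only on the coordinates. It should therefore give the same result whether or not a given JAX version already treats "constant" this way.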