jax

map_coordinates mode='constant' not properly applied

Open · agudallago opened this issue 4 years ago · 6 comments

jax.scipy.ndimage.map_coordinates behaves differently from scipy.ndimage.map_coordinates when using mode="constant". In the example below, the coordinate 2.9, which should be just as "off the array" as 3, produces a different output: JAX returns 1, which is wrong, whereas the original SciPy implementation returns 0 (the cval).

from scipy import ndimage
import jax

import jax.numpy as jnp

src = jnp.array([5, 1, 6])  # Fails!
# src = jnp.array([5, 1, 5]) # Doesn't fail

# 2.9 is first floored (jnp.floor) to 2, which is a valid index.
coords = [[2, 3, 2.9]]


def scipy_map_coordinates():
  return ndimage.map_coordinates(src, coords, order=1, mode="constant")


def jax_map_coordinates():
  return jax.scipy.ndimage.map_coordinates(
      src, coords, order=1, mode="constant")


print(scipy_map_coordinates())
print(jax_map_coordinates())

assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())

Yielding:

[6 0 0]
[6 0 1]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-b13f0479754d> in <module>()
     20 print(jax_map_coordinates())
     21 
---> 22 assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())

AssertionError: 

CC: @claudiofantacci

agudallago · Feb 09 '21 12:02

[5, 1, 5] doesn't fail because 5 * weight gets internally rounded down to zero, matching the original implementation's output only "by luck". The internal tests apply an odd padding that turns the input [5, 1, 6] into [5, 1, 6, 0], which doesn't seem consistent with the expected behaviour.
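The rounding explanation above can be checked with a few lines of NumPy. This is a minimal sketch, not the JAX source: it assumes the coordinate is held in float32 (JAX's default dtype), that the lower neighbour's weight is 1 minus the fractional part, that the out-of-range upper neighbour (index 3) contributes cval = 0, and that results for integer inputs are rounded half away from zero. The helper names `round_half_away_from_zero` and `value_at_2_9` are hypothetical.

```python
import numpy as np

def round_half_away_from_zero(x):
    # Rounding rule assumed for integer-valued inputs.
    return np.sign(x) * np.floor(np.abs(x) + 0.5)

def value_at_2_9(last, cval=0):
    """Linear interpolation at 2.9 between src[2] = last and an
    out-of-range src[3] treated as cval (grid-constant style)."""
    coord = np.float32(2.9)                      # float32, JAX's default dtype
    lower = np.floor(coord)                      # 2.0, a valid index
    upper_weight = coord - lower                 # ~0.9000001, not exactly 0.9
    lower_weight = np.float32(1) - upper_weight  # ~0.0999999, just below 0.1
    raw = np.float32(last) * lower_weight + upper_weight * np.float32(cval)
    return raw, round_half_away_from_zero(raw)

for last in (6, 5):
    raw, rounded = value_at_2_9(last)
    print(f"src[2] = {last}: raw = {raw:.7f}, rounded = {rounded:.0f}")
```

With src[2] = 6 the raw value is just under 0.6 and rounds to 1 (the wrong answer in the report); with src[2] = 5 it is just under 0.5 and rounds to 0, which only coincidentally matches SciPy.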

agudallago · Feb 09 '21 16:02

I think this relates to some of the discrepancies in boundary handling that we noticed in the behavior of SciPy's map_coordinates: https://github.com/scipy/scipy/issues/2640

The SciPy issue was closed by https://github.com/scipy/scipy/pull/12776. It sounds like the decision there was to add a new mode, grid-constant, rather than attempt to fix constant. I'm not sure that's the choice I would have made, but it is probably the right decision for us, too.

shoyer · Feb 10 '21 01:02

Hi @shoyer, thanks for your answer!

I apologize for my description; it may not have been as clear as it sounded in my head.

From my experiments, it seems that the JAX version of map_coordinates is equivalent to mode="grid-constant" in SciPy 1.6, i.e. the JAX version interpolates even outside the edges. See https://docs.scipy.org/doc/scipy/reference/tutorial/ndimage.html#ndimage-interpolation-modes for details of what I mean, in particular order=1 with mode="grid-constant".

src = jnp.array([5,1,6])
coords = [[3]]

# This returns src[3], which is out of range, so cval = 0 -> OK!
print(jax.scipy.ndimage.map_coordinates(src, coords, order=1, mode="constant"))

# This returns an interpolated "src[2.9]" != 0 -> NOT OK for "constant", OK for "grid-constant"
coords = [[2.9]]
print(jax.scipy.ndimage.map_coordinates(src, coords, order=1, mode="constant"))
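To make the difference concrete, here is a minimal pure-Python sketch of order=1 (linear) interpolation on a 1-D array with "grid-constant" semantics: every out-of-range neighbour reads as cval, and interpolation still happens at and beyond the edges. The function name `linear_grid_constant_1d` is hypothetical, and it uses Python floats (float64), so it illustrates the boundary rule rather than reproducing JAX's float32 arithmetic bit for bit. Rounding the result for integer inputs is an assumption chosen to match the integer outputs reported in this issue.

```python
import math

def linear_grid_constant_1d(src, coords, cval=0.0, round_ints=True):
    """Order-1 interpolation where every out-of-range sample is cval
    ("grid-constant" style); interpolation also happens past the edges."""
    def sample(i):
        # Out-of-range neighbours read as the constant fill value.
        return src[i] if 0 <= i < len(src) else cval

    out = []
    for c in coords:
        lower = math.floor(c)
        upper_weight = c - lower
        value = (1 - upper_weight) * sample(lower) + upper_weight * sample(lower + 1)
        if round_ints:
            # Round half away from zero, as integer outputs require.
            value = math.copysign(math.floor(abs(value) + 0.5), value)
        out.append(value)
    return out

print(linear_grid_constant_1d([5, 1, 6], [2, 3, 2.9]))  # -> [6.0, 0.0, 1.0]
```

This reproduces the JAX output [6 0 1] from the report: at 2.9 roughly 10% of the weight still falls on src[2] = 6, so the result is not the cval.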

So, to be consistent with SciPy and to agree with the documentation, I think we should:

  • rename constant to grid_constant
  • (optional) implement constant as SciPy does.

What do you think? Is this what you meant, in other words? For what it's worth, in this particular mode and situation I don't think SciPy's output is wrong.
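For order=1, SciPy-style mode="constant" output could be obtained from a "grid-constant" result by replacing every point whose coordinate lies outside [0, n - 1] on some axis with cval, since inside that range both modes use the same two neighbours. The sketch below assumes exactly that boundary rule (it reproduces SciPy's [6 0 0] from the report) and only argues it for linear interpolation; the helper name `apply_constant_mode` is hypothetical, and it takes the JAX output quoted in the issue as input rather than calling JAX.

```python
import numpy as np

def apply_constant_mode(grid_constant_result, coordinates, shape, cval=0.0):
    """Turn an order-1 "grid-constant" result into "constant" semantics
    (assumed rule): points outside [0, n - 1] on any axis become cval."""
    coordinates = [np.asarray(c, dtype=float) for c in coordinates]
    inside = np.ones(coordinates[0].shape, dtype=bool)
    for c, n in zip(coordinates, shape):
        # A point is kept only if it lies within the sampled extent on every axis.
        inside &= (c >= 0) & (c <= n - 1)
    return np.where(inside, grid_constant_result, cval)

# JAX output reported in the issue for src = [5, 1, 6], coords = [[2, 3, 2.9]].
jax_like = np.array([6, 0, 1])
print(apply_constant_mode(jax_like, [[2, 3, 2.9]], shape=(3,)))  # -> [6. 0. 0.]
```

Keeping the mask separate from the interpolation mirrors the idea of treating "constant" and "grid-constant" as distinct modes; the shortcut should not be assumed to hold for orders above 1 without checking against SciPy.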

agudallago · Feb 10 '21 17:02

So in order to be consistent with scipy and also agree with the documentation, I think we should both:

  • rename constant to grid_constant
  • (optional) implement constant as SciPy does.

What do you think? Is this what you also meant in other words?

Yes, exactly!

For what it's worth, in this particular mode and situation, I don't think that SciPy's output is wrong.

I believe SciPy has updated its documentation for what mode='constant' means, alongside the introduction of mode='grid-constant'.

shoyer · Feb 10 '21 17:02

Hi @agudallago,

rename constant to grid_constant

The documentation of jax.scipy.ndimage.map_coordinates now mentions that mode='constant' in JAX behaves like mode='grid-constant' in SciPy. Please check the documentation here: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html

Thank you.

rajasekharporeddy · Jun 21 '24 13:06

We should update the implementation to match scipy now that they've fixed the issue on their end.

jakevdp · Jun 21 '24 14:06