
pure_callback is broken with multiple vmap

Open Joshuaalbert opened this issue 1 year ago • 7 comments

Description

When vectorized=True, the expectation is that the callback passed to pure_callback vectorises over common leading batch dims. That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on JAX-side. If an array has not been mapped, then it should not receive a batch dim. If this is violated, it is impossible for the callback to construct the proper output shape.

from functools import partial
import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def add_vmapped(x, y, z):
    return x + y + z


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_no_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == ()
        assert y.shape == ()
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=False)


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)


if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert add_vmapped(x, y, z).shape == (4, 5)
    assert cb_no_vec(x, y, z).shape == (4, 5)
    assert cb_vec(x, y, z).shape == (4, 5)

System info (python version, jaxlib version, accelerator, etc.)

jax==0.4.31
jaxlib==0.4.31

(Joshuaalbert, Sep 13 '24)
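For reference, here is a plain NumPy sketch (illustrative only, not what JAX does) of the inputs the issue expects the vectorized callback in cb_vec to receive. The mapped x and y each get their own batch axis and are broadcast to the common (4, 5) batch shape, while the never-mapped z stays a scalar. The names x_cb, y_cb and z_cb are hypothetical.

```python
import numpy as np

x = np.arange(4, dtype=np.float32)  # mapped by the outer vmap
y = np.arange(5, dtype=np.float32)  # mapped by the inner vmap
z = np.float32(1)                   # never mapped

# Give each mapped argument its own batch axis (outer vmap -> axis 0,
# inner vmap -> axis 1), then broadcast them to the common batch shape.
x_cb, y_cb = np.broadcast_arrays(x[:, None], y[None, :])
# The unmapped argument is passed through unchanged.
z_cb = np.asarray(z)

out = x_cb + y_cb + z_cb
```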

Similar to https://github.com/google/jax/issues/17187, though I'm not sure I follow the logic of this comment.

(Joshuaalbert, Sep 13 '24)

"That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on JAX-side."

This actually isn't the behavior of vectorized! I know that the way it's presented in the docs is confusing, and I'm actually pushing to deprecate the vectorized behavior in favor of a more expressive API. I think that what you want is something like a "broadcasting vmap", which can be built using custom_vmap. Something like the following should do the trick:

def broadcasting_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (axis_size,)), args,
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f

(dfm, Sep 13 '24)

It might be just the trick. However, can I suggest you make sure it passes this?

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)


if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert cb_vec(x, y, z).shape == (4, 5)

(Joshuaalbert, Sep 13 '24)

Are you sure you want assert z.shape == ()? My suggestion was that you write:

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap  # <--------------------------- HERE
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == (4, 5)
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z)

(dfm, Sep 13 '24)

The problem is that z should be a scalar inside the func, not broadcast. Note that this is not ufunc behaviour, but it is what I am looking for: mapped args are broadcast, unmapped args are not.


(Joshuaalbert, Sep 13 '24)

I don't think there's any good way to get that behavior. The inner vmap doesn't "know" about the outer one, so I expect you'll be hard-pressed to come up with consistent logic that ends up with z as a scalar. One thing you probably could get is shapes (4, 1), (1, 5), and (1, 1), if that's better for your use case:

A possible implementation
def joshuaalbert_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (1,)), args,  # <- 1 instead of axis_size
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@joshuaalbert_vmap
def cb_broadcasting(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 1)
        assert y.shape == (1, 5)
        assert z.shape == (1, 1)
        return x + y + z
    out_shape = jnp.broadcast_shapes(x.shape, y.shape, z.shape)  # <-- note here
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=out_shape, dtype=x.dtype), x, y, z)

The issue is that there needs to be some logic for which arguments to broadcast in each vmap, and that logic can't depend on whether or not an argument is going to be mapped in the future. "vectorized" handles this by never broadcasting anything that isn't mapped, and I think it's unlikely that we could come up with sensible logic to get exactly what you're asking for here. All that to say, I do think that you might be able to come up with something that works for your use case using custom_vmap, and maybe that will help clarify your feature request.

(dfm, Sep 13 '24)
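The per-level logic described in the comment above can be modelled in a few lines of plain Python. This is a hypothetical model (callback_shapes, LEVELS and the "keep"/"one"/"full" modes are illustrative names, not JAX API). Each vmap level, innermost first, prepends its axis size to the arguments it maps. An argument it does not map is either left untouched (as vectorized=True does), given a size-1 axis (as joshuaalbert_vmap does), or broadcast to the full axis size (as broadcasting_vmap does).

```python
def callback_shapes(arg_shapes, levels, unbatched):
    """Hypothetical model of the argument shapes a batched callback receives.

    arg_shapes: per-argument shapes before any vmap is applied.
    levels: (axis_size, in_batched) pairs, innermost vmap first.
    unbatched: what a level does to an argument it does not map:
      "keep" -> leave it untouched (vectorized=True)
      "one"  -> prepend an axis of size 1 (joshuaalbert_vmap)
      "full" -> prepend an axis of size axis_size (broadcasting_vmap)
    """
    shapes = [tuple(s) for s in arg_shapes]
    for axis_size, in_batched in levels:
        new_shapes = []
        for shape, batched in zip(shapes, in_batched):
            if batched:
                # A mapped argument always gains this level's batch axis.
                new_shapes.append((axis_size,) + shape)
            elif unbatched == "keep":
                new_shapes.append(shape)
            elif unbatched == "one":
                new_shapes.append((1,) + shape)
            elif unbatched == "full":
                new_shapes.append((axis_size,) + shape)
            else:
                raise ValueError(f"unknown mode: {unbatched!r}")
        shapes = new_shapes
    return shapes


# The issue's setup: x, y, z are scalars; the inner vmap maps y (size 5),
# the outer vmap maps x (size 4); z is never mapped.
LEVELS = [(5, (False, True, False)), (4, (True, False, False))]

results = {mode: callback_shapes([(), (), ()], LEVELS, mode)
           for mode in ("keep", "one", "full")}
# "keep" -> [(4,), (5,), ()]
# "one"  -> [(4, 1), (1, 5), (1, 1)]
# "full" -> [(4, 5), (4, 5), (4, 5)]
```

Under "keep", x arrives as (4,) and y as (5,). Nothing in those shapes says which vmap each axis came from, which is why the vectorized callback in the issue cannot build a (4, 5) output.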

I understand the constraint. Hmm, perhaps there is another middle ground. In principle, if an argument should never be broadcast, then it can be curried. The remaining args can then be broadcast to convert the function into a ufunc-style func. To make the API clear, you might merge both broadcast choices above into one helper, convert_to_ufunc, with a tile boolean that determines whether unmapped arguments are broadcast to the full batch size (tile=True) or only given a size-1 axis (tile=False).

def convert_to_ufunc(f, tile: bool = True):
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        batched_args = jax.tree.map(
            lambda x, b: x if b else jax.lax.broadcast(x, ((axis_size if tile else 1),)), args,
            tuple(in_batched))
        out = f(*batched_args)
        out_batched = jax.tree.map(lambda _: True, out)
        return out, out_batched

    return f

def cb(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)  # if tile=True; (4, 1) if tile=False
        assert y.shape == (4, 5)  # if tile=True; (1, 5) if tile=False
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=jnp.broadcast_shapes(x.shape, y.shape), dtype=x.dtype), x,
                             y, z, vectorized=True)

# Curry z first
assert jax.vmap(jax.vmap(convert_to_ufunc(partial(cb, z=z)), in_axes=(None, 0)), in_axes=(0, None))(x, y).shape == (4, 5)

With this setup, the original intent of this issue is resolved: we can now trust that applying vmap multiple times gives consistent shapes inside the callback, which makes it easier to reason about.

(Joshuaalbert, Sep 14 '24)
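The consistency claim for convert_to_ufunc can be checked with a small runnable sketch. It restates the helper in condensed form, curries z with a lambda-style closure instead of functools.partial, and records the shapes the host callback sees. The names run and seen are illustrative additions, not from the thread. With tile=True the mapped arguments should arrive as (4, 5) each, and with tile=False as (4, 1) and (1, 5). In both cases the curried z should stay a scalar, because it never passes through the custom_vmap rule.

```python
import jax
import jax.numpy as jnp
import numpy as np


def convert_to_ufunc(f, tile: bool = True):
    # Same helper as in the comment above, lightly condensed.
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Unmapped args get a new leading axis of size axis_size (tile=True)
        # or of size 1 (tile=False); mapped args are left as they are.
        size = axis_size if tile else 1
        batched_args = jax.tree.map(
            lambda a, b: a if b else jax.lax.broadcast(a, (size,)),
            args, tuple(in_batched))
        out = f(*batched_args)
        return out, jax.tree.map(lambda _: True, out)

    return f


seen = {}  # tile -> shapes the host callback received


def run(x, y, z, tile):
    def add(xs, ys, zs):
        # Host-side callback: record input shapes, then add with NumPy.
        seen[tile] = (np.shape(xs), np.shape(ys), np.shape(zs))
        return np.asarray(xs + ys + zs, dtype=np.float32)

    def cb(x, y):
        # z is closed over (curried), so the custom_vmap rule never sees it.
        out_shape = jnp.broadcast_shapes(x.shape, y.shape)
        return jax.pure_callback(
            add, jax.ShapeDtypeStruct(out_shape, x.dtype), x, y, z)

    f = convert_to_ufunc(cb, tile=tile)
    return jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)


x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

out_tiled = run(x, y, z, tile=True)     # callback sees (4, 5), (4, 5), ()
out_untiled = run(x, y, z, tile=False)  # callback sees (4, 1), (1, 5), ()
```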

I'm going to close this now that pure_callback has a vmap_method parameter and the vectorized parameter is deprecated. Please let me know if there's something else that was missed from this discussion!

(dfm, Jan 02 '25)