
`vmap` of `pure_callback(..., vectorized=True)` fails on broadcasts

Open Gattocrucco opened this issue 1 year ago • 10 comments

Description

The following code defines a function that adds a (2, 3) array and a (3,) array, both obtained internally by reshaping the input rather than passed as arguments. If the sum is performed with pure_callback(..., vectorized=True) and the function is vmapped, it fails. The batching rule of pure_callback does not broadcast the two operands to (..., 2, 3) and (..., 1, 3). Instead it passes them to the callback directly as (..., 2, 3) and (..., 3), which do not broadcast against each other.

The problem can be bypassed by broadcasting the arrays beforehand, but this is not efficient in general: if the broadcast arrays are larger than necessary, the callback performs redundant operations and uses extra memory. (In my experience this can be a serious bottleneck.)
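A minimal NumPy sketch of the shape problem and of the two ways around it. It assumes, as a simplification and not as JAX's actual batching code, that `vectorized=True` stacks each per-example operand along a new leading axis and calls the callback once. Padding `b` to `(1, 3)` per example, instead of broadcasting it to `(2, 3)`, is enough for that single call to broadcast.

```python
import numpy as np

batch = 5
# Per-example operands as in the issue: a is (2, 3), b is (3,).
xs = np.ones((batch, 9))
a_per = [x[0:6].reshape(2, 3) for x in xs]
b_per = [x[6:9].reshape(3) for x in xs]

# Emulated vectorized=True call: operands stacked along a new leading axis.
a_batched = np.stack(a_per)  # (5, 2, 3)
b_batched = np.stack(b_per)  # (5, 3)
try:
    np.add(a_batched, b_batched)
    stacked_ok = True
except ValueError:
    stacked_ok = False  # (5, 2, 3) and (5, 3) are aligned from the right and clash

# Fix 1: broadcast b to (2, 3) per example before stacking (6 elements each).
b_full = np.stack([np.broadcast_to(b, (2, 3)) for b in b_per])  # (5, 2, 3)
# Fix 2: pad b with a unit axis to (1, 3) per example (3 elements each).
b_padded = np.stack([b[np.newaxis, :] for b in b_per])  # (5, 1, 3)

out_full = np.add(a_batched, b_full)
out_padded = np.add(a_batched, b_padded)
```

Both fixes give the same result; here the padded operand holds 15 elements instead of 30, and the saving grows with the size of the broadcast axis.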

import jax
from jax import numpy as jnp

def f(x, use_callback, vectorized, broadcast):
    a = x[0:6].reshape(2, 3)
    b = x[6:9].reshape(3)
    if use_callback:
        result = jax.ShapeDtypeStruct((2, 3), x.dtype)
        if broadcast:
            a, b = jnp.broadcast_arrays(a, b)
        return jax.pure_callback(
            lambda a, b: a + b,
            result, a, b,
            vectorized=vectorized,
        )
    else:
        return a + b

f_vectorize = jnp.vectorize(f, signature='(9)->(2,3)', excluded=(1, 2, 3))
f_vmap = jax.vmap(f, (0, None, None, None))

def tryout(task, *args):
    try:
        return task(*args)
    except Exception as exc:
        print(f'{task.__name__}({", ".join(map(str, args))}) fails with:')
        print(exc.__class__.__name__, exc.args[0].split('\n')[0])
        return exc

tryout(f, jnp.ones(9), False, None, None)
tryout(f_vmap, jnp.ones((5, 9)), False, None, None)
tryout(f_vectorize, jnp.ones((5, 9)), False, None, None)
tryout(f, jnp.ones(9), True, True, False)
tryout(f_vmap, jnp.ones((5, 9)), True, False, False)
tryout(f_vectorize, jnp.ones((5, 9)), True, False, False)
tryout(f_vmap, jnp.ones((5, 9)), True, True, False)      #  FAILS
tryout(f_vectorize, jnp.ones((5, 9)), True, True, False) #  FAILS
tryout(f_vmap, jnp.ones((5, 9)), True, True, True)      #  SUCCEEDS
tryout(f_vectorize, jnp.ones((5, 9)), True, True, True) #  SUCCEEDS

Output:

f([[1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]], True, True, False) fails with:
XlaRuntimeError INTERNAL: Generated function failed: CpuCallback error: ValueError: operands could not be broadcast together with shapes (5,2,3) (5,3) 
f([[1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]], True, True, False) fails with:
XlaRuntimeError INTERNAL: Generated function failed: CpuCallback error: ValueError: operands could not be broadcast together with shapes (5,2,3) (5,3) 

What jax/jaxlib version are you using?

jax and jaxlib 0.4.14

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.11.2 (v3.11.2:878ead1ac1, Feb 7 2023, 10:02:41) [Clang 13.0.0 (clang-1300.0.29.30)], macOS 13.4

NVIDIA GPU info

No response

Gattocrucco commented Aug 18 '23

Sorry, I'm having trouble working out what your functions are actually doing (long day...). Could you show an example of a simple call that fails that you think should not fail, without the indirection of all the flags and the tryout function?

jakevdp commented Aug 18 '23

Ok, sorry. The minimal example is:

@jax.vmap
def f(x):
    a = x[0:6].reshape(2, 3)
    b = x[6:9].reshape(3)
    result = jax.ShapeDtypeStruct((2, 3), x.dtype)
    return jax.pure_callback(jnp.add, result, a, b, vectorized=True)

f(jnp.ones((5,9)))
XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: Incompatible shapes for broadcasting: shapes=[(5, 2, 3), (5, 3)]

Gattocrucco commented Aug 18 '23

Thanks - edited to make the result specification more canonical.

@sharadmv, can you take a look?

jakevdp commented Aug 18 '23

On second thought, the behavior I expect requires pure_callback to treat the callback as a ufunc (as opposed to a generalized ufunc). This is not stated in the documentation. However, since it works with vectorized=False (see the complete example in the first message), there is at least an inconsistency.

My personal preference would be that the callback is treated as a ufunc, and that in the future pure_callback could be extended with a signature parameter to batch generalized ufuncs correctly.
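The ufunc/gufunc distinction can be illustrated in plain NumPy. `np.add` is a ufunc: every axis is a loop (broadcasting) axis, aligned from the right, so `(5, 2, 3)` and `(5, 3)` clash. A generalized ufunc with signature `(n,m),(m)->(n,m)` marks the trailing axes as core dimensions and treats only the leading ones as loop axes, which is the information a hypothetical `signature` argument to `pure_callback` would carry. `np.vectorize` is used here only as a stand-in gufunc; `add_core` is a name chosen for this sketch.

```python
import numpy as np

def add_core(a, b):
    # Core operation on one example: (n, m) + (m,) -> (n, m).
    return a + b

# gufunc view: the last two axes of the first operand and the last axis of the
# second are core dims; any axes before them are loop (batch) axes.
add_gufunc = np.vectorize(add_core, signature='(n,m),(m)->(n,m)')

a = np.ones((5, 2, 3))
b = np.ones((5, 3))

# Loop dims (5,) and (5,) broadcast; add_core is called once per batch element.
out = add_gufunc(a, b)

try:
    np.add(a, b)  # ufunc view: all axes are loop axes, aligned from the right
    ufunc_ok = True
except ValueError:
    ufunc_ok = False
```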

Gattocrucco commented Aug 18 '23

If someone encounters this same issue, in the meantime here is the workaround I'm going to use:

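import numpy as np
import jax
import jax.numpy as jnp

def pure_callback_ufunc(callback, dtype, *args, **kwargs):
    # Broadcast shape of the per-example operands, e.g. (2, 3) for (2, 3) and (3,).
    shape = jnp.broadcast_shapes(*(arg.shape for arg in args))
    ndim = len(shape)
    # Left-pad every operand with unit axes up to the common rank, so that the
    # operands still broadcast after vmap prepends the batch axis to each one.
    padded_args = [
        jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
        for a in args
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)
    return jax.pure_callback(callback, result, *padded_args, vectorized=True, **kwargs)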

@jax.vmap
def f(x):
    a = x[0:6].reshape(2, 3)
    b = x[6:9].reshape(3)
    return pure_callback_ufunc(np.add, x.dtype, a, b)

f(jnp.ones((5,9)))

Gattocrucco commented Aug 18 '23

This is expected behaviour for pure_callback.

With vectorized=False, the callback is applied separately to each batch element, for which the shapes (2,3) and (3,) are broadcastable. With vectorized=True, the callback is invoked only once, with shapes (5,2,3) and (5,3), which are not broadcastable.

ufunc behaviour cannot be the default, as other users may wish to use pure_callback for functions which are not ufuncs.

If using vectorized=True then it is the programmer's responsibility to ensure that the callback can accept the additional batch dimensions.
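The two modes can be emulated in NumPy. This assumes, as a simplification rather than JAX's actual implementation, that `vectorized=False` calls the callback once per batch element and `vectorized=True` calls it once on the batched operands as they are. The helper names `call_sequential` and `call_vectorized` are hypothetical. The per-element calls succeed because `(2, 3)` and `(3,)` broadcast; the single call does not.

```python
import numpy as np

def callback(a, b):
    return np.add(a, b)

a = np.ones((5, 2, 3))  # batched operand, batch axis 0
b = np.ones((5, 3))     # batched operand, batch axis 0

def call_sequential(cb, *args):
    # Emulates vectorized=False: one call per batch element, results stacked.
    size = args[0].shape[0]
    return np.stack([cb(*(x[i] for x in args)) for i in range(size)])

def call_vectorized(cb, *args):
    # Emulates vectorized=True: a single call on the batched operands as-is.
    return cb(*args)

seq_out = call_sequential(callback, a, b)
try:
    call_vectorized(callback, a, b)
    vec_ok = True
except ValueError:
    vec_ok = False
```

Under this reading, `vectorized=True` is safe only when the callback itself accepts the extra leading axis in the form it arrives.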

patrick-kidger commented Aug 21 '23

I know this is the current behavior as defined, but I disagree that it is "expected" by the user (me, in particular). It is surprising, and it breaks ufuncs, which are a very common kind of function.

I understand it is impossible to do this in the right way automatically, so I guess that if the functionality of pure_callback is not to be modified, the documentation should at least explain this problem.

Personally, I'd prefer if pure_callback allowed me to pass in a signature for the vectorized parameter. This would not work for pytrees in the input or output, though, only for standard gufuncs.

Gattocrucco commented Aug 21 '23

I think I've run into the same issue. In the below example, when vectorized=True it appears to be impossible for the callback to know whether it is supposed to return a 2x2 array or a 2-vector, since it sees the exact same input in both cases. Nor is it possible for the enclosing function to specify the return shape via additional arguments, because it does not see the desired return shape when transformed with vmap (unless there's some way to extract this that I don't know about):

import numpy as np
import jax
import jax.numpy as jnp

def g(x, y):
    bsh = jnp.broadcast_shapes(x.shape, y.shape)
    sds = jax.ShapeDtypeStruct(shape=bsh, dtype=jnp.promote_types(x.dtype, y.dtype))
    return jax.pure_callback(np.add, sds, x, y, vectorized=True)

e = jnp.ones(2)
g(e, e)
jax.vmap(g, (0, 0))(e, e)  # [2. 2.]; works
jax.vmap(jax.vmap(jnp.add, (0, None)), (None, 0))(e, e)  # [[2., 2.], [2., 2.]]; works
jax.vmap(jax.vmap(g, (0, None)), (None, 0))(e, e)  # error: expected callback to return shape (2, 2), got (2,)
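The ambiguity can be made concrete by writing out, as explicit NumPy arrays, the operands the callback would receive. This assumes the batching rule described later in this thread: batched operands get their batch axis moved to the front, unbatched operands are passed unchanged. Both `vmap` patterns then hand `np.add` the same two `(2,)` vectors. For contrast, the sketch also spells out a hypothetical rule that gives each unbatched operand a size-1 axis at every `vmap` level, under which the nested case becomes `(1, 2)` with `(2, 1)` and broadcasts to `(2, 2)`.

```python
import numpy as np

e = np.ones(2)

# vmap(g, (0, 0))(e, e): both operands are batched at the only level,
# so the callback receives the two (2,) vectors as they are.
single_args = (e, e)

# vmap(vmap(g, (0, None)), (None, 0))(e, e), unbatched operands left unchanged:
#   inner level: x batched -> (2,), y unbatched scalar -> ()
#   outer level: x unbatched -> stays (2,), y batched -> (2,)
nested_legacy_args = (e, e)

# Hypothetical rule: unbatched operands get a size-1 axis at each level:
#   inner level: x -> (2,), y -> (1,)
#   outer level: x -> (1, 2), y -> (2, 1)
nested_padded_args = (e.reshape(1, 2), e.reshape(2, 1))

single_out = np.add(*single_args)                # (2,), as expected
nested_legacy_out = np.add(*nested_legacy_args)  # (2,), but (2, 2) was expected
nested_padded_out = np.add(*nested_padded_args)  # (2, 2)
```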

terhorst commented Nov 02 '23

JAX is a great library, and I would like to thank the developers for their hard work on this awesome, clever, and powerful software!

Regarding the issue, the docs say that vectorized indicates whether the callback can handle arrays with additional leading dimensions, and that when it is True, the callback is called directly on inputs with leading batch dimensions during the vmap transformation. I read this as meaning that "if the callback supports built-in broadcasting over leading dims, then vectorized can be set to True".

Below are the traced jaxprs when we use jax.pure_callback(np.matmul, ..., vectorized=True), the example mentioned in the docs.

Traced with the current batching implementation, this produces a direct call:

{ lambda ; a:f32[8,4,2,5] b:f32[8,5,2]. let
    c:f32[8,4,2,2] = pure_callback[
      callback=<function pure_callback.<locals>._flat_callback at 0x1452eed40>
      result_avals=(ShapedArray(float32[8,4,2,2]),)
      sharding=None
      vectorized=True
    ] a b
  in (c,) }

Here, the callback is invoked on (f32[8,4,2,5], f32[8,5,2]), which do not broadcast, because the inner vmap did not inject a unit dim. Thus the callback with vectorized=True fails to get the promised batching in the operands.

On the other hand, if we trace with a rule based on built-in broadcasting (see the code at the end of this comment), the jaxpr reads

{ lambda ; a:f32[8,4,2,5] b:f32[8,5,2]. let
    c:f32[8,1,5,2] = broadcast_in_dim[
      broadcast_dimensions=(0, 2, 3)
      shape=(8, 1, 5, 2)
    ] b
    d:f32[8,4,2,2] = pure_callback[
      callback=<function pure_callback.<locals>._flat_callback at 0x1452eede0>
      result_avals=(ShapedArray(float32[8,4,2,2]),)
      sharding=None
      vectorized=True
    ] a c
  in (d,) }

Here the callback receives correctly batched arguments (f32[8,4,2,5], f32[8,1,5,2]).

Looking at broadcast_batcher and the traced jaxprs, it seems that the current callback batching logic unnecessarily skips broadcasting tracers with dim=None.
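The effect of the missing unit axis can be checked with `np.matmul` alone, using the shapes from the two jaxprs. With `(8, 4, 2, 5)` and `(8, 5, 2)` the batch dims `(8, 4)` and `(8,)` are aligned from the right and clash; with `(8, 1, 5, 2)` they broadcast to `(8, 4)` and the result is `(8, 4, 2, 2)`. The arrays are built like `A` and `B` in the script in this comment, with `np.linspace` standing in for `jp.r_`.

```python
import numpy as np

A = np.linspace(-1, 1, 80).reshape(8, 5, 2)
B = np.linspace(-1, 1, 320).reshape(8, 4, 2, 5)

# Current rule: the operand unmapped by the inner vmap gets no unit axis.
try:
    np.matmul(B, A)
    direct_ok = True
except ValueError:
    direct_ok = False

# Broadcasting rule: A gains a size-1 axis where the inner vmap mapped B's 4.
A_padded = A[:, np.newaxis]   # (8, 1, 5, 2)
out = np.matmul(B, A_padded)  # (8, 4, 2, 2)

# Reference: the per-example product described by the nested vmap.
ref = np.stack([np.stack([B[i, j] @ A[i] for j in range(4)]) for i in range(8)])
```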


The following code was used to produce the jaxprs, and contains a provisional implementation of a batching rule:

import numpy as np
import jax
from jax import numpy as jp

# we will tinker with primitive defs and rules
from jax._src.interpreters import batching as bat
from jax._src import callback as cb, core

def jp_matmul(a, b, *, vectorized=True):
    """numpy matmul wrapped as a pure-callback for illustration"""
    # sanity check
    assert jp.ndim(a) > 1 and jp.ndim(b) > 1
    (*batch_a, n, ka), (*batch_b, kb, m) = a.shape, b.shape
    assert ka == kb

    # build the result spec: broadcast batch dims followed by (n, m)
    shape = *jp.broadcast_shapes(batch_a, batch_b), n, m
    aval = jax.ShapeDtypeStruct(shape, jp.result_type(a.dtype, b.dtype))

    # make the call
    return jax.pure_callback(np.matmul, aval, a, b, vectorized=vectorized)

# new rule that delegates to bat.broadcast_batcher and fixes up `result_avals`
def new_pure_callback_batching_rule(
    args,
    dims,
    *,
    callback,
    sharding,
    vectorized,
    result_avals,
):
    if not vectorized:
        return cb.pure_callback_batching_rule(
            args,
            dims,
            callback=callback,
            sharding=sharding,
            vectorized=vectorized,
            result_avals=result_avals,
        )

    # fixup `result_avals`
    axis_size = next(a.shape[0] for a, d in zip(args, dims)
                     if d is not bat.not_mapped)
    new_result_avals = tuple(
        core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
        for aval in result_avals
    )

    # vectorized=True suggests that the callback has built-in broadcasting
    return bat.broadcast_batcher(
        cb.pure_callback_p,
        args,
        dims,
        callback=callback,
        sharding=sharding,
        vectorized=vectorized,
        result_avals=new_result_avals,
    )

# create two batches of matrices
A = jp.r_[-1:+1:80j].reshape(8, 5, 2)
B = jp.r_[-1:+1:320j].reshape(8, 4, 2, 5)

# jointly batch over 8, and then on B's 4
# XXX the jp_matmul sees `f32[2, 5]` and `f32[5, 2]`
foo = jax.vmap(jax.vmap(jp_matmul, (0, None)), (0, 0))

# trace the pure callback using the current autobatch impl
assert bat.primitive_batchers[cb.pure_callback_p] is cb.pure_callback_batching_rule
jaxpr1 = jax.make_jaxpr(foo)(B, A)

# use the custom batcher
current_vmap_impl = bat.primitive_batchers[cb.pure_callback_p]
try:
    bat.primitive_batchers[cb.pure_callback_p] = new_pure_callback_batching_rule
    jaxpr2 = jax.make_jaxpr(foo)(B, A)

finally:
    bat.primitive_batchers[cb.pure_callback_p] = current_vmap_impl  # restore

ivannz commented Apr 22 '24

@patrick-kidger

ufunc behaviour cannot be the default, as other users may wish to use pure_callback for functions which are not ufuncs.

I'm not sure this makes sense. If mapped input arrays are not promised to have common, pre-broadcast batch dims, how is the callback supposed to know what shape to construct? The ufunc formalism exists exactly for this problem. Please provide evidence of users who want non-ufunc behaviour, or else I invoke Hitchens's razor.

Joshuaalbert commented Sep 13 '24