quax

`scan` primitive implementation missing

Status: Open. colehaus opened this issue 1 year ago · 4 comments

Looking at https://github.com/patrick-kidger/quax/blob/a9d875e323f8159bb230d9a8880981629ee4d7a5/quax/_core.py#L551, I assume scan is supposed to be implemented generically for all Quax objects. I took a quick try at an implementation like this:

from typing import Optional, Union

import jax
import jax.tree_util as jtu
import quax
from jax import lax
from jax.typing import ArrayLike


@quax.register(lax.scan_p)
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    # scan_p's operands arrive flat: [*consts, *init (carry), *xs].
    consts = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]

    # Re-trace the scan body under Quax. Note that this traces it with the
    # full xs rather than with one per-iteration slice of xs.
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))
    quax_jaxpr = jax.make_jaxpr(quax_f)(*consts, *init, *xs)

    const_leaves, _ = jtu.tree_flatten(consts)
    init_leaves, init_treedef = jtu.tree_flatten(init)
    xs_leaves, _ = jtu.tree_flatten(xs)

    out_flat = lax.scan_p.bind(
        *const_leaves,
        *init_leaves,
        *xs_leaves,
        reverse=reverse,
        length=length,
        jaxpr=quax_jaxpr,
        num_consts=num_consts,
        num_carry=num_carry,
        linear=linear,
        unroll=unroll,
        _split_transpose=_split_transpose,
    )

    # _initial_style_jaxpr(quax_f, , , "scan")
    carry_nvals = len(init_leaves)
    carry, ys = out_flat[:carry_nvals], out_flat[carry_nvals:]

    carry_out = jtu.tree_unflatten(init_treedef, carry)

    # Not implemented yet: ys is dropped because its out_treedef is unknown here.
    return carry_out, None

But there are at least two problems:

  • When I actually use it, I get:
File /usr/local/lib/python3.11/dist-packages/quax/_core.py:194, in <listcomp>(.0)
    192         out = method(*values, **params)
    193 if primitive.multiple_results:
--> 194     return [_QuaxTracer(self, _wrap_if_array(x)) for x in out]  # pyright: ignore
    195 else:
    196     return _QuaxTracer(self, _wrap_if_array(out))

File /usr/local/lib/python3.11/dist-packages/quax/_core.py:84, in _QuaxTracer.__init__(self, trace, value)
     83 def __init__(self, trace: "_QuaxTrace", value: "Value"):
---> 84     assert _is_value(value)
     85     self._trace = trace
     86     self.value = value
  • I haven't implemented the part to extract the out_treedef for the second element of the return value. I think this would be possible by following the same strategy as what's in https://github.com/google/jax/blob/ebc6c1815297c79bc1c9c907aaf858d70caef5e6/jax/_src/lax/control_flow/loops.py#L123, but I'm not sure if there's a simpler way.
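To see the flat layout such a rule receives, one can inspect the jaxpr that lax.scan produces using only public JAX APIs. The sketch below is illustrative: the body `step` and its dict-shaped ys are assumptions made up for the example. The scan equation's operands are laid out as [*consts, *carry leaves, *xs leaves] and its results as [*carry leaves, *ys leaves]; nothing in the equation records that ys was a dict, which is why the out treedef has to be recovered from somewhere other than the primitive's params.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Hypothetical scan body: the carry is a 2-tuple and ys is a dict pytree.
    c0, c1 = carry
    return (c0 + x, c1 * x), {"sum": c0 + 2.0 * x, "diff": c1 - x}


def run(init, xs):
    return lax.scan(step, init, xs)


init = (jnp.zeros(3), jnp.ones(3))
xs = jnp.arange(12.0).reshape(4, 3)

closed = jax.make_jaxpr(run)(init, xs)
eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "scan")
params = eqn.params

# Operands are flat: [*consts, *carry leaves, *xs leaves].
num_xs = len(eqn.invars) - params["num_consts"] - params["num_carry"]
# Results are flat too: [*carry leaves, *ys leaves]; the dict keys are gone.
num_ys = len(eqn.outvars) - params["num_carry"]

print(params["length"], params["num_carry"], num_xs, num_ys)
```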

Having scan available is pretty handy for use with the scan over layers technique.
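For context, here is a minimal plain-JAX sketch of the scan-over-layers pattern (no Quax involved; `layer`, `forward`, and the shapes are illustrative assumptions). The weights of n identical layers are stacked along a leading axis and lax.scan applies one layer per iteration, threading the activations through as the carry, so the layer is traced once rather than n times.

```python
import jax
import jax.numpy as jnp
from jax import lax


def layer(x, w):
    # One hypothetical layer: a dense map followed by tanh.
    return jnp.tanh(x @ w)


def forward(stacked_w, x):
    # Each scan iteration consumes one slice stacked_w[i] as its "x" input.
    def step(h, w):
        return layer(h, w), None

    h, _ = lax.scan(step, x, stacked_w)
    return h


n_layers, d = 4, 8
stacked_w = jax.random.normal(jax.random.PRNGKey(0), (n_layers, d, d)) / jnp.sqrt(d)
x = jnp.ones((2, d))
y = forward(stacked_w, x)

# Reference: the same layers applied with an explicit Python loop.
ref = x
for i in range(n_layers):
    ref = layer(ref, stacked_w[i])
```

The scanned version and the unrolled loop compute the same function; the difference is that the scan keeps compile time roughly constant in the number of layers.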

colehaus, Sep 04 '24 05:09

Very happy to take a PR on this. I'm afraid I'm not immediately sure what the issue is here, though!

patrick-kidger, Sep 04 '24 06:09

Something like this seems to roughly work:

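from typing import Optional, Union

import jax
import jax.tree_util as jtu
import quax
from jax import lax
from jax.typing import ArrayLike

# Private JAX internals (locations assumed as of mid-2024 JAX; they may move).
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.util import safe_map


@quax.register(lax.scan_p)
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    # scan_p's operands arrive flat: [*consts, *init (carry), *xs].
    const = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]

    const_flat, _ = jtu.tree_flatten(const)
    const_avals = tuple(safe_map(_abstractify, const_flat))

    xs_flat, _ = jtu.tree_flatten(xs)
    xs_avals = tuple(safe_map(_abstractify, xs_flat))
    # Per-iteration avals: drop the leading (scanned-over) axis of each xs leaf.
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]

    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))

    init_flat, init_treedef = jtu.tree_flatten(init)
    carry_avals = tuple(safe_map(_abstractify, init_flat))

    # Trace the quaxified body into a new closed jaxpr, as lax.scan does internally.
    in_flat, in_treedef = jtu.tree_flatten(const + init + xs)
    jaxpr, consts, out_treedef = _initial_style_jaxpr(
        quax_f, in_treedef, (*const_avals, *carry_avals, *x_avals), "scan"
    )

    out_flat = lax.scan_p.bind(
        *consts,
        *in_flat,
        reverse=reverse,
        length=length,
        jaxpr=jaxpr,
        # Operands are now [*consts, *const_flat, *init_flat, *xs_flat], and a
        # Quax value may flatten to several leaves, so recount from the leaves.
        num_consts=len(consts) + len(const_flat),
        num_carry=len(init_flat),
        linear=(False,) * (len(consts) + len(in_flat)),
        unroll=unroll,
        _split_transpose=_split_transpose,
    )
    carry_out = jtu.tree_unflatten(init_treedef, out_flat[: init_treedef.num_leaves])
    num_extra_outs = out_treedef.num_leaves - init_treedef.num_leaves
    # The ys structure is unknown here, so ys comes back as a flat tuple.
    flat_structure = jtu.tree_structure((0,) * num_extra_outs)
    extra_out = jtu.tree_unflatten(flat_structure, out_flat[init_treedef.num_leaves :])
    carry_ret = carry_out[0] if len(init) == 1 else carry_out
    extra_ret = extra_out[0] if num_extra_outs == 1 else extra_out
    return carry_ret, extra_ret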

Two problems:

  • If the first output value is a nested pytree and not a simple array, we get an error:
File /usr/local/lib/python3.11/dist-packages/quax/_core.py:194, in <listcomp>(.0)
    192         out = method(*values, **params)
    193 if primitive.multiple_results:
--> 194     return [_QuaxTracer(self, _wrap_if_array(x)) for x in out]  # pyright: ignore
    195 else:
    196     return _QuaxTracer(self, _wrap_if_array(out))

File /usr/local/lib/python3.11/dist-packages/quax/_core.py:84, in _QuaxTracer.__init__(self, trace, value)
     83 def __init__(self, trace: "_QuaxTrace", value: "Value"):
---> 84     assert _is_value(value)
     85     self._trace = trace
     86     self.value = value

I think this might be a mistake in Quax and the wrapping and asserting need to be tree-mapped over x?

  • I don't think there's any easy way to find out the proper PytreeDef for the full output. The PytreeDef for the carry part of the output can be inferred from the input. But the structure of the extra output (the b in the scan signature (c -> a -> (c, b)) -> c -> [a] -> (c, [b]) from the JAX docs) comes from the user's function itself, and the version of the function reconstructed via jaxpr_as_fun doesn't have this structure: it only produces flat leaves. So this rough implementation doesn't respect the structure passed in by the user, and always returns the extra output as a flat tuple with the corresponding number of leaves.
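The asymmetry between carry and extra output can be shown with jax.tree_util alone. Below is a minimal sketch; `split_scan_outputs` is a hypothetical helper, not part of Quax or JAX, that mirrors what the rough implementation above does: the carry treedef is taken from init (scan requires the output carry to match it), while the ys leaves can only be handed back as a flat tuple.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def split_scan_outputs(init, out_flat):
    """Hypothetical helper: split scan_p's flat results into (carry, ys).

    The carry can be rebuilt with init's treedef. The ys leaves carry no
    structure at the primitive level, so they are returned as a flat tuple.
    """
    init_treedef = jtu.tree_structure(init)
    n_carry = init_treedef.num_leaves
    carry = jtu.tree_unflatten(init_treedef, out_flat[:n_carry])
    ys = tuple(out_flat[n_carry:])
    return carry, ys


# Usage: a dict carry (JAX flattens dicts in sorted key order: "c", then "h").
init = {"h": jnp.zeros(2), "c": jnp.ones(2)}
out_flat = [jnp.full(2, 1.0), jnp.full(2, 2.0), jnp.full((3, 2), 3.0), jnp.full((3, 2), 4.0)]
carry, ys = split_scan_outputs(init, out_flat)
```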

colehaus, Sep 04 '24 22:09

And here's an implementation of remat that seems to work (important for use with scan: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes):

import jax
import jax.tree_util as jtu
import quax

# Private JAX internals (locations assumed as of mid-2024 JAX; they may move).
import jax._src.ad_checkpoint
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.util import safe_map

@quax.register(jax._src.ad_checkpoint.remat_p)
def _(*args, jaxpr, prevent_cse, differentiated, policy):
    del prevent_cse, differentiated, policy
    # `jaxpr_as_fun` expects a closed jaxpr. `scan_p` already gets one but `remat_p` doesn't.
    closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(closed_jaxpr))
    in_flat, in_treedef = jtu.tree_flatten(args)
    in_avals = tuple(safe_map(_abstractify, in_flat))
    quax_jaxpr, consts, out_tree = _initial_style_jaxpr(quax_f, in_treedef, in_avals, "remat")
    out_flat = jax.core.eval_jaxpr(quax_jaxpr.jaxpr, (), *consts, *in_flat)
    return jtu.tree_unflatten(out_tree, out_flat)
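Since this rule is motivated by use with scan, here is a minimal plain-JAX sketch of the pattern it would need to handle: a scan whose body is wrapped in jax.checkpoint, which is what places remat_p equations inside the scan's jaxpr. The names and shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def layer(x, w):
    # One hypothetical layer: a dense map followed by tanh.
    return jnp.tanh(x @ w)


def loss(stacked_w, x):
    # jax.checkpoint marks the body for rematerialisation: its intermediates
    # are recomputed on the backward pass instead of being stored for every
    # iteration of the scan.
    @jax.checkpoint
    def step(h, w):
        return layer(h, w), None

    h, _ = lax.scan(step, x, stacked_w)
    return jnp.sum(h ** 2)


stacked_w = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4)) / 2.0
x = jnp.ones((2, 4))
grads = jax.grad(loss)(stacked_w, x)
```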

colehaus, Sep 05 '24 00:09

> I think this might be a mistake in Quax and the wrapping and asserting need to be tree-mapped over x?

I don't think so. A primitive bind produces either a single array or a flat sequence of arrays. By the time we're this deep in the JAX internals, pytrees have largely disappeared.
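To make the bind convention concrete, here is a small sketch using two public jax.lax primitives, chosen only as examples. Judging from the Quax traceback earlier in the thread, when a primitive has multiple_results set, Quax wraps each element of the rule's output individually, so each element presumably needs to be a single array or Quax value rather than a nested tuple of them.

```python
import jax.numpy as jnp
from jax import lax

# Every JAX primitive declares whether bind returns one array or a sequence.
print(lax.add_p.multiple_results)   # add: a single array
print(lax.scan_p.multiple_results)  # scan: a flat list, [*carry, *ys]

# Binding a single-result primitive directly gives back one array, not a pytree.
total = lax.add_p.bind(jnp.float32(1.0), jnp.float32(2.0))
print(total)
```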

Anyway, these broadly all look good to me! I'd be happy to take PRs on these, including tests for the kinds of edge cases you're bumping into.

patrick-kidger, Sep 05 '24 06:09