Add control flow operators for iterated functions

Open carlosgmartin opened this issue 3 years ago • 16 comments

Add the following control flow operators for iterated functions:

# iterate : ∀a. ℕ → (a → a) → (a → a)
def iterate(n, f, x):
    for _ in range(n):
        x = f(x)
    return x

# orbit : ∀a. ℕ → (a → a) → (a → [a])
def orbit(n, f, x):
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)
    return xs

iterate and fori_loop are mutually definable, but iterate has a simpler signature, its semantics are more common in my experience, and it is arguably more "basic" or "fundamental".
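As a rough sketch of that mutual definability: the helper names `iterate_via_fori` and `fori_via_iterate` below are made up for illustration, and the second one uses the plain-Python `iterate` from above rather than anything in JAX.

```python
from jax import lax


def iterate(n, f, x):
    # Plain-Python reference, same as the definition proposed above.
    for _ in range(n):
        x = f(x)
    return x


def iterate_via_fori(n, f, x):
    # fori_loop hands the loop index to the body; iterate simply ignores it.
    return lax.fori_loop(0, n, lambda i, x: f(x), x)


def fori_via_iterate(lower, upper, body, x):
    # Thread the index through the state as an (index, value) pair.
    step = lambda state: (state[0] + 1, body(state[0], state[1]))
    return iterate(upper - lower, step, (lower, x))[1]


f = lambda x: 2 * x + 1
print(iterate_via_fori(10, f, 3))                      # 4095
print(fori_via_iterate(0, 10, lambda i, x: x + i, 0))  # 45
```

Going from fori_loop to iterate only drops an argument, while the other direction has to pack the counter into the state, which is one way to read "simpler signature".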

carlosgmartin avatar Mar 18 '22 18:03 carlosgmartin

Interesting idea - if I understand your suggestion correctly, I think the implementations would look like this:

def iterate(n, f, x):
  return lax.scan(lambda x, _: (f(x), x), x, None, length=n)[0]

def orbit(n, f, x):
  return lax.scan(lambda x, _: (f(x), x), x, None, length=n + 1)[1]

Is that what you have in mind?

jakevdp avatar Mar 18 '22 19:03 jakevdp

For iterate we may want lax.scan(lambda x, _: (f(x), None), x, None, length=n)[0], i.e. don't have an extensive output.

mattjj avatar Mar 18 '22 21:03 mattjj

@jakevdp @mattjj Looks right:

from jax.numpy import array
from jax.lax import scan

def iterate_1(n, f, x):
    for _ in range(n):
        x = f(x)
    return x

def iterate_2(n, f, x):
  return scan(lambda x, _: (f(x), x), x, None, length=n)[0]

def iterate_3(n, f, x):
    return scan(lambda x, _: (f(x), None), x, None, length=n)[0]

def orbit_1(n, f, x):
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)
    return array(xs)

def orbit_2(n, f, x):
    return scan(lambda x, _: (f(x), x), x, None, length=n + 1)[1]

n = 10
f = lambda x: 2 * x + 1
x = 3
print(iterate_1(n, f, x)) # 4095
print(iterate_2(n, f, x)) # 4095
print(iterate_3(n, f, x)) # 4095
print(orbit_1(n, f, x)) # [   3    7   15   31   63  127  255  511 1023 2047 4095]
print(orbit_2(n, f, x)) # [   3    7   15   31   63  127  255  511 1023 2047 4095]

carlosgmartin avatar Mar 19 '22 02:03 carlosgmartin

Out of curiosity, can this function be implemented in terms of lax primitives?

from jax.numpy import array

# orbit_while : ∀a. (a → 𝔹) → (a → a) → (a → [a])
def orbit_while(p, f, x):
    xs = []
    while p(x):
        xs.append(x)
        x = f(x)
    return array(xs) # contains only elements satisfying p

p = lambda x: x < 2047
f = lambda x: 2 * x + 1
x = 3

print(orbit_while(p, f, x)) # [   3    7   15   31   63  127  255  511 1023]

carlosgmartin avatar Mar 19 '22 03:03 carlosgmartin

A further question for discussion: are these new APIs necessary if they can be implemented via a single call to scan? One might argue that iterate and orbit already exist in JAX; they're just called scan. What do you think?

jakevdp avatar Mar 19 '22 03:03 jakevdp

@jakevdp I think it's convenient to have these useful helper functions in the library to make it easier for users to adopt JAX, and to save them the trouble of figuring out how to implement them in terms of scan. After all, I don't think it'd be good to take away all existing control flow functions that can be implemented in terms of scan.

carlosgmartin avatar Mar 19 '22 03:03 carlosgmartin

Sure, but adding new APIs does not come without maintenance costs. It's true that while_loop and fori_loop can be implemented in terms of scan, but their implementations are far more involved than a single line.

As this is the first time I'm aware of receiving such a request, I would lean toward not including them in the package API, but I'm happy to change my mind if there are compelling reasons to do so.

jakevdp avatar Mar 21 '22 15:03 jakevdp

I'd much rather have a bounded while loop.

soraros avatar Mar 21 '22 19:03 soraros

@jakevdp I think this works as a single-line implementation of fori_loop:

from jax.lax import fori_loop, scan

def fori_loop_scan(a, b, f, x):
  return scan(lambda ix, _: ((ix[0] + 1, f(*ix)), None), (a, x), None, b - a)[0][1]

a = 2
b = 8
f = lambda i, x: 3 + 2 * x + 4 * i + i * x
x = 10

y1 = fori_loop(a, b, f, x)
y2 = fori_loop_scan(a, b, f, x)
print(y1) # 827986
print(y2) # 827986

Out of curiosity, what's the implementation of while_loop in terms of scan?

Also, is it possible to implement orbit_while in terms of lax primitives?

@soraros What do you mean?

carlosgmartin avatar Mar 21 '22 23:03 carlosgmartin

while_loop doesn't lower to scan, actually, since scan requires a static number of iterations.

jakevdp avatar Mar 22 '22 00:03 jakevdp

And orbit_while is not currently possible to express in JIT-compatible JAX, because it returns an array of dynamic length.
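One possible workaround, sketched here under the assumption that an upper bound on the orbit length is known ahead of time, is to return a fixed-length padded array together with a validity mask. The name `padded_orbit_while` and its `max_len` and `fill` parameters are hypothetical, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax


def padded_orbit_while(p, f, x, max_len, fill=0):
    # Hypothetical helper: the output length is always max_len, so it can be jitted.
    def step(carry, _):
        x, alive = carry
        alive = alive & p(x)              # once p fails, stay stopped
        out = jnp.where(alive, x, fill)   # padding after the orbit ends
        x = jnp.where(alive, f(x), x)     # freeze x once stopped
        return (x, alive), (out, alive)

    init = (jnp.asarray(x), jnp.asarray(True))
    _, (xs, mask) = lax.scan(step, init, None, length=max_len)
    return xs, mask


p = lambda x: x < 2047
f = lambda x: 2 * x + 1

xs, mask = jax.jit(lambda x: padded_orbit_while(p, f, x, 12))(3)
print(xs[mask])  # [   3    7   15   31   63  127  255  511 1023]
```

The boolean indexing `xs[mask]` produces the dynamically sized result, so it has to happen outside of jit. Note also that `f` is still evaluated on every one of the `max_len` steps; the mask only discards the results.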

jakevdp avatar Mar 22 '22 00:03 jakevdp

@jakevdp That's what I suspected. Thanks.

carlosgmartin avatar Mar 22 '22 01:03 carlosgmartin

@soraros By a bounded while loop, do you mean something like this?

import jax

def bounded_while_loop(p, f, x, n):
    def g(i, x):
        return jax.lax.cond(p(x), f, lambda x: x, x)
    return jax.lax.fori_loop(0, n, g, x)

def p(x):
    return x < 10

def f(x):
    return x + 1

x = 0
print(jax.lax.while_loop(p, f, x)) # 10
print(bounded_while_loop(p, f, x, 100)) # 10
print(bounded_while_loop(p, f, x, 5)) # 5

carlosgmartin avatar Oct 17 '22 22:10 carlosgmartin

@jakevdp What do you think of letting xs=None by default in scan? That pattern seems to occur often.
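The pattern in question is a scan whose body ignores its per-step input. Independently of what scan's defaults are, a thin wrapper gives the same ergonomics; `scan_n` below is a hypothetical name used only for this sketch.

```python
from jax import lax


def scan_n(f, init, length):
    # Hypothetical wrapper: f maps carry -> (new_carry, output), with no per-step input.
    return lax.scan(lambda carry, _: f(carry), init, None, length=length)


# Example: the first ten Fibonacci numbers, carrying the pair (a, b).
step = lambda c: ((c[1], c[0] + c[1]), c[0])
(a, b), fib = scan_n(step, (0, 1), 10)
print(fib)  # [ 0  1  1  2  3  5  8 13 21 34]
```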

carlosgmartin avatar Nov 11 '22 06:11 carlosgmartin

I think that could be an improvement, but I'd want to hear opinions from other folks on the team.

jakevdp avatar Nov 11 '22 12:11 jakevdp

I'd also like to suggest the following functions:

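from typing import Callable, Optional

import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map

def foldl(f: Callable, h, xs, length: Optional[int] = None):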
    '''
    http://zvon.org/other/haskell/Outputprelude/foldl_f.html
    (a → b → a) → a → [b] → a
    Arguments:
    `f`: Function.
    `h`: Initial value.
    `xs`: Inputs.
    `length`: Optional number of iterations, passed to `lax.scan`.
    Returns:
    Final value.
    '''
    def g(h, x):
        return f(h, x), None
    h, _ = lax.scan(g, h, xs, length)
    return h

def scanl(f: Callable, h, xs, length: Optional[int] = None):
    '''
    http://zvon.org/other/haskell/Outputprelude/scanl_f.html
    (a → b → a) → a → [b] → [a]
    Arguments:
    `f`: Function.
    `h`: Initial value.
    `xs`: Inputs.
    `length`: Optional number of iterations, passed to `lax.scan`.
    Returns:
    Intermediate values.
    '''
    def g(h, x):
        return f(h, x), h
    h, hs = lax.scan(g, h, xs, length)
    return tree_map(lambda h, hs: jnp.concatenate((hs, h[None])), h, hs)

These are very general and useful abstractions for processing sequences.

I think it's a good idea to add common control-flow constructs like these to the standard library. This saves users the trouble of having to figure out how to implement them in terms of scan, which has a more complex interface. (The latter can be especially inconvenient for new users not accustomed to the purely-functional approach.)
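For a sense of how these would read at a call site, here is a small self-contained sketch. It repeats compact versions of the two functions above; the running-sum and Horner examples are illustrations chosen here, not taken from any JAX API.

```python
import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map


def foldl(f, h, xs, length=None):
    # Keep only the final carry.
    h, _ = lax.scan(lambda h, x: (f(h, x), None), h, xs, length)
    return h


def scanl(f, h, xs, length=None):
    # Emit the carry before each step, then append the final carry.
    h, hs = lax.scan(lambda h, x: (f(h, x), h), h, xs, length)
    return tree_map(lambda h, hs: jnp.concatenate((hs, h[None])), h, hs)


xs = jnp.arange(1, 6)
zero = jnp.zeros((), xs.dtype)  # initial value with the same dtype as xs
print(foldl(lambda h, x: h + x, zero, xs))  # 15
print(scanl(lambda h, x: h + x, zero, xs))  # [ 0  1  3  6 10 15]

# Horner's rule as a left fold: evaluate 2*t**2 - 3*t + 1 at t = 2.
coeffs = jnp.array([2.0, -3.0, 1.0])
t = 2.0
print(foldl(lambda acc, c: acc * t + c, jnp.zeros((), coeffs.dtype), coeffs))  # 3.0
```

As in Haskell, scanl returns one more element than its input because the initial value is included.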

carlosgmartin avatar Feb 23 '23 04:02 carlosgmartin