
Target JAX's intermediate representation with the JAX linker

Open rlouf opened this issue 3 years ago • 6 comments

The JAX linker currently targets the library's NumPy-like high-level API. For instance, the Dot Op is translated using jax.numpy.dot:

# jax_funcify is the singledispatch function Aesara uses to translate each Op;
# jnp is jax.numpy.
import jax.numpy as jnp

@jax_funcify.register(Dot)
def jax_funcify_Dot(op, **kwargs):
    def dot(x, y):
        return jnp.dot(x, y)

    return dot

However, JAX is a symbolic library (albeit a limited one) and has its own intermediate representation. When the user calls a function written with jax.numpy primitives for the first time, JAX traces the function and converts it to a Jaxpr that is then processed by XLA. Therefore, when one transpiles their Aesara code to JAX and runs it, the resulting code is traced. This is completely unnecessary, since all the information needed to build JAX's intermediate representation is already contained in the Aesara graph.
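To make the tracing overhead concrete, here is a small, standalone illustration (not tied to Aesara) of when tracing happens: the Python print only runs while JAX traces, and a new input shape triggers a re-trace.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # A Python side effect: it only runs while JAX traces the function,
    # not when the compiled version is executed.
    print("tracing f")
    return jnp.sin(x) + 1.0

f(jnp.ones(3))  # prints "tracing f", then compiles and runs
f(jnp.ones(3))  # no print: the cached compiled function is reused
f(jnp.ones(4))  # new shape -> traced and compiled again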

We could therefore, in theory, translate Aesara's Ops directly to JAX's intermediate representation. We would not only improve runtime performance (gain to be estimated), but would also have more freedom for the transpilation, since we would not be limited to JAX's high-level API.
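As a rough illustration of that extra freedom (a hypothetical sketch, not the dispatcher's actual code), the Dot Op above could be lowered to the lax primitive that would end up in the jaxpr anyway, instead of going through the jax.numpy wrapper:

from jax import lax

# Hypothetical alternative body for jax_funcify_Dot: target the low-level
# dot_general primitive directly. For 2-D inputs this is equivalent to
# jnp.dot(x, y); a real implementation would also have to handle the
# vector-vector and matrix-vector cases.
def dot(x, y):
    return lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())))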

Proof of concept

Before opening a PR, I will try in the comments of this issue to translate the following Aesara graph:

import aesara
import aesara.tensor as at

a = at.vector()
b = at.vector()

c = a + b

aesara.dprint(c)
# Elemwise{add,no_inplace} [id A]
#  |<TensorType(float64, (None,))> [id B]
#  |<TensorType(float64, (None,))> [id C]

To its JAX equivalent:

import jax.numpy as jnp
from jax import lax
from jax import make_jaxpr

def add_fn(a, b):
    return lax.add(a, b)

print(make_jaxpr(add_fn)(jnp.array([1., 1.]), jnp.array([1., 1.])))
# { lambda ; a:f32[2] b:f32[2]. let c:f32[2] = add a b in (c,) }

This example is simple, but it raises the question of how types and shapes are handled in JAX's IR. In particular, I am currently not sure that JAX can handle arrays of unknown (but fixed) length. If it cannot, we can imagine a "delayed transpilation" where Aesara would generate JAX's IR only when the function is called with arguments.
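A quick experiment (reusing add_fn and make_jaxpr from above) suggests that the traced shapes are baked into the jaxpr, so two input lengths give two different jaxprs:

print(make_jaxpr(add_fn)(jnp.ones(2), jnp.ones(2)))
# { lambda ; a:f32[2] b:f32[2]. let c:f32[2] = add a b in (c,) }
print(make_jaxpr(add_fn)(jnp.ones(3), jnp.ones(3)))
# { lambda ; a:f32[3] b:f32[3]. let c:f32[3] = add a b in (c,) }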

rlouf avatar Jun 24 '22 15:06 rlouf

Here are a couple of issues for which the idea was considered as a possible solution:

  • https://github.com/aesara-devs/aesara/issues/43#issuecomment-708029409
  • https://github.com/aesara-devs/aesara/issues/68#issue-713222807

brandonwillard avatar Jun 24 '22 16:06 brandonwillard

> Here are a couple of issues for which the idea was considered as a possible solution:
>
>   • https://github.com/aesara-devs/aesara/issues/43#issuecomment-708029409
>   • https://github.com/aesara-devs/aesara/issues/68#issue-713222807

Can we actually circumvent these limitations this way?

In addition, would Aesara still be able to generate JAX code that we can pass around to other libraries (e.g., BlackJax, which will call grad/JIT on a user-defined JAX function)?

ricardoV94 avatar Jun 26 '22 05:06 ricardoV94

> In addition, would Aesara still be able to generate JAX code that we can pass around to other libraries (e.g., BlackJax, which will call grad/JIT on a user-defined JAX function)?

This is a legitimate concern, and something we should figure out before investing too much time in it.

rlouf avatar Jun 26 '22 07:06 rlouf

To follow up on the previous discussion: we were considering the following function:

from jax import lax
from jax import make_jaxpr
import jax.numpy as jnp

def add_fn(a, b):
    return lax.add(a, b)

x = jnp.array([1., 1.])
y = jnp.array([2., 3.])
add_fn(x, y)
# [3., 4.]

JAX traces the user's functions to translate them to (Closed) Jaxprs, and those contain information about the shape and type of the inputs:

from jax import make_jaxpr

add_jaxpr = make_jaxpr(add_fn)(x, y)
add_jaxpr
# { lambda ; a:f32[2] b:f32[2]. let c:f32[2] = add a b in (c,) }

add_1d_jaxpr = make_jaxpr(add_fn)(1., 1.)
add_1d_jaxpr
# { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) }

The Jaxprs are Python objects whose structure we can inspect:

add_1d_jaxpr.jaxpr.eqns
# [a:f32[] = add b c]
add_1d_jaxpr.jaxpr.invars
# [a, b]
add_1d_jaxpr.jaxpr.outvars
# [c]
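For completeness, the shape and dtype information lives on these variables' abstract values (output shown approximately):

add_1d_jaxpr.jaxpr.invars[0].aval
# ShapedArray(float32[], weak_type=True)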

More interestingly, we can get an object that behaves like a function from a ClosedJaxpr using what the JAX developers call an interpreter:

from jax.core import jaxpr_as_fun

add_1d = jaxpr_as_fun(add_1d_jaxpr)
add_1d
# functools.partial(<function jaxpr_as_fun at 0x7f115d9ece50>, { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) })

This needs to be double-checked, but it seems that no tracing is happening anymore; I can, for instance, pass the x and y arrays to the function built from the Jaxpr obtained by tracing with scalars:

add_1d(x, y)
# [DeviceArray([3., 4.], dtype=float32)]

As explained in the internals documentation, the interpreter itself is traceable, so we can JIT-compile this function:

import jax

jitted_add1 = jax.jit(add_1d)
make_jaxpr(jitted_add1)(1., 1.)
# { lambda ; a:f32[] b:f32[]. let
#     c:f32[] = xla_call[
#       call_jaxpr={ lambda ; d:f32[] e:f32[]. let f:f32[] = add d e in (f,) }
#       name=<unnamed wrapped function>
#     ] a b
#   in (c,) }
jitted_add1(1., 2.)
# [DeviceArray(3., dtype=float32, weak_type=True)]

It feels safe to target Jaxprs for now. The next step is to rebuild the function add_fn by constructing the ClosedJaxpr manually (i.e. not by tracing a Python function). Then we will try to understand what happens when jax.jit traces evaluated Jaxprs.

Unrelated note

We should be able to determine the largest jit-able (sub)set of the code by doing static analysis of the corresponding Aesara graph; jit obeys very simple rules, and those can be checked at compile time. This may be an appreciated feature and could potentially allow us to transpile code that has tensors of varying shapes, for instance.
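To give an idea of what such a check could look like (a very rough sketch; JITTABLE_OPS is hypothetical, and io_toposort is assumed to be the appropriate graph-traversal helper):

from aesara.graph.basic import io_toposort

# Hypothetical whitelist of Op classes whose JAX translation is known to be
# jit-compatible; it would be populated from the dispatcher.
JITTABLE_OPS: set = set()

def non_jitable_nodes(inputs, outputs):
    """Walk the Aesara graph and collect Apply nodes whose Op is not jit-able."""
    return [
        node
        for node in io_toposort(inputs, outputs)
        if type(node.op) not in JITTABLE_OPS
    ]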

It may still be possible to completely jit functions that use e.g. jax.numpy.reshape, but we may need to implement our own jitting function (aesara.link.jax.jit). We need to explore XLA's primitives to see what the true limitations are here (as opposed to those baked into JAX). We could use JAX merely as Python bindings to XLA and lower the Jaxprs we create ourselves to functions.

As far as I understand the motivation behind the omnistaging change in JAX (https://github.com/google/jax/pull/3370), the issues it tries to solve can be circumvented when one has a symbolic graph that can be analyzed.

This file is a good starting point for the translations from Ops to XLA. I see mentions of MLIR in this file; if XLA can interpret MLIR, we may want to target MLIR directly. There is a roadmap, but it is hard to know whether and when this is going to be done; if JAX starts lowering to MLIR, there is a good chance it will happen.

rlouf avatar Jul 04 '22 20:07 rlouf

Cool. I assume other transformations like grad and vmap can also be performed in the same way as jit, after calling jaxpr_as_fun?

ricardoV94 avatar Jul 09 '22 10:07 ricardoV94

~~Yes.~~

jax.grad requires tracing the function to build a "new graph", so it will not be possible to pass a function built this way as an argument to it. This is a minor inconvenience, as Aesara can compute the gradients itself.

However, jax.jit, jax.vmap and jax.pmap (and the loops) would work with these functions.
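For instance (a quick sketch, reusing the add_1d function built from the Jaxpr above; output shown approximately), vmap composes with the interpreter-built function just like jit does:

import jax
import jax.numpy as jnp

batched_add = jax.vmap(add_1d)
batched_add(jnp.arange(3.), jnp.arange(3.))
# [DeviceArray([0., 2., 4.], dtype=float32)]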

rlouf avatar Jul 09 '22 12:07 rlouf

It is now clear that by targeting JAX's IR directly we would still be able to use jax.jit and jax.vmap on the compiled function, but we would not be able to apply transformations like jax.grad. There is no free lunch.

What we do from here depends on the goals we set for the transpilation: if it is compatibility with the JAX ecosystem, then the approach the dispatcher currently takes is the most appropriate. If we want to target XLA while avoiding JAX's self-imposed limitations (i.e., build a JAX replacement of sorts), then we might as well go all the way, target XLA's IR directly, and use jaxlib as a bridge.

I believe that, short term, we should aim for compatibility with the broader JAX ecosystem. It is fairly simple, it allows Aesara to piggyback on a much broader ecosystem, and we all know the size of the ecosystem is critical when it comes to adoption.

We can nevertheless address some of the issues that motivated this thread by working on the Aesara side: for instance, by making sure that shapes that are known to be constant at compile time are indeed set to constant values before compiling. When it comes to known limitations of JAX, like dynamic shapes, we can fail gracefully and explain that this is due to a limitation on JAX's side. For things that JAX traces out, like assert statements, I would simply warn the user that they have been removed because of a limitation on JAX's side. Users still get the many benefits of Aesara, like its rewrite system, while being able to use their favorite library (hopefully they will eventually see the value in porting said library to Aesara).

Nevertheless, XLA remains an interesting target in itself for GPU and TPU. I think it is worth diving into the XLA documentation directly to figure out what we may gain from bypassing JAX altogether.

rlouf avatar Sep 13 '22 15:09 rlouf