Speed of function that was jitted with eqx.filter_jit is slower than Flax's jit equivalent on CPU

Open • Artur-Galstyan opened this issue 1 year ago • 1 comment

Hi there,

I was comparing the performance of different ML libraries (basically to see whether JAX is faster than PyTorch or TensorFlow; the consensus is that JAX is faster, but I want numbers) and I noticed that the function generated by eqx.filter_jit was quite slow. It's easier to show in code, so here is the very simple benchmark:

The pip installs to reproduce this:
pip install equinox tqdm polars

The timeit util (not the main point of the question, but here for completeness):
import time
from functools import wraps

import polars as pl
from tqdm import tqdm


def timeit(
    func, print_time: bool = False, name: str | None = None, n_repeats: int = 10
):
    @wraps(func)
    def wrapper(*args, **kwargs):
        df = []
        result = func(*args, **kwargs)
        for _ in tqdm(range(n_repeats)):
            start = time.time()
            func(*args, **kwargs)
            end = time.time()
            df.append(end - start)
        if print_time:
            print(
                f"{func.__name__ if name is None else name} took {sum(df) / n_repeats:.2f} seconds on average"
            )

        return result, pl.DataFrame(
            {
                "function": func.__name__ if name is None else name,
                "avg_time": sum(df) / n_repeats,
                "n_repeats": n_repeats,
                "min": min(df),
                "max": max(df),
            }
        )

    return wrapper

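# Imports needed by the benchmark (jax is installed as a dependency of equinox)
import equinox as eqx
import jax


class EquinoxLinearModule(eqx.Module):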
    linear: eqx.nn.Linear

    def __init__(self, in_features: int, out_features: int):
        self.linear = eqx.nn.Linear(in_features, out_features, key=jax.random.key(0))

    def __call__(self, x):
        return self.linear(x)


def benchmark_equinox_linear(min_power: int = 1, max_power: int = 4):
    powers = [i for i in range(min_power, max_power + 1)]
    res = []
    for power in powers:
        lin = EquinoxLinearModule(10**power, 10**power)
        lin_jit = eqx.filter_jit(lin)  # very slow
        # lin_jit = eqx.filter_jit(lin.__call__)  # equally slow
        # lin_jit = jax.jit(lin)  # error
        # lin_jit = jax.jit(lin.__call__)  # extremely fast
        x = jax.random.normal(jax.random.PRNGKey(0), (10**power, 10**power))
        func = lambda: lin_jit(x)
        _, df = timeit(func, name=f"eqx-lin-{10}^{power}")()
        res.append(df)
    res = pl.concat([df for df in res], how="vertical")
    return res

eqx.filter_jit gives this result:

┌──────────────┬──────────┬───────────┬──────────┬──────────┐
│ function     ┆ avg_time ┆ n_repeats ┆ min      ┆ max      │
│ ---          ┆ ---      ┆ ---       ┆ ---      ┆ ---      │
│ str          ┆ f64      ┆ i64       ┆ f64      ┆ f64      │
╞══════════════╪══════════╪═══════════╪══════════╪══════════╡
│ eqx-lin-10^4 ┆ 2.937645 ┆ 10        ┆ 2.888101 ┆ 3.034286 │
└──────────────┴──────────┴───────────┴──────────┴──────────┘

whereas the jax.jit(lin.__call__) and the Flax version give these results respectively (which are basically equivalent):

┌──────────────┬──────────┬───────────┬──────────┬──────────┐
│ function     ┆ avg_time ┆ n_repeats ┆ min      ┆ max      │
│ ---          ┆ ---      ┆ ---       ┆ ---      ┆ ---      │
│ str          ┆ f64      ┆ i64       ┆ f64      ┆ f64      │
╞══════════════╪══════════╪═══════════╪══════════╪══════════╡
│ eqx-lin-10^4 ┆ 0.000013 ┆ 10        ┆ 0.000007 ┆ 0.000039 │
└──────────────┴──────────┴───────────┴──────────┴──────────┘
┌───────────────┬──────────┬───────────┬──────────┬──────────┐
│ function      ┆ avg_time ┆ n_repeats ┆ min      ┆ max      │
│ ---           ┆ ---      ┆ ---       ┆ ---      ┆ ---      │
│ str           ┆ f64      ┆ i64       ┆ f64      ┆ f64      │
╞═══════════════╪══════════╪═══════════╪══════════╪══════════╡
│ flax-lin-10^4 ┆ 0.000068 ┆ 10        ┆ 0.000028 ┆ 0.000394 │
└───────────────┴──────────┴───────────┴──────────┴──────────┘

So this left me a bit confused, because I always thought that eqx.filter_jit was just a thin wrapper around jax.jit, but that wouldn't explain the large difference. My tests were performed on an M1 MacBook, on the CPU.

Artur-Galstyan avatar Nov 17 '24 20:11 Artur-Galstyan

You've forgotten to call .block_until_ready(). Equinox will actually call this for you automatically:

https://github.com/patrick-kidger/equinox/blob/15a800dd0ab1fc91b033c9305a5fe2f7bf2aecae/equinox/_jit.py#L248

But the others don't do this by default.
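To illustrate the point about blocking, here is a minimal sketch (not code from this thread; the helper name `time_call` and the matrix size are hypothetical choices) that times a `jax.jit`-compiled matmul once without waiting for the result and once with `jax.block_until_ready`:

```python
import time

import jax


def time_call(fn, *args, block: bool, n_repeats: int = 5) -> float:
    """Return the best wall-clock time of fn(*args) over n_repeats runs.

    If block is True, wait for the result with jax.block_until_ready so the
    measurement covers the computation and not only the dispatch of the call.
    """
    best = float("inf")
    for _ in range(n_repeats):
        start = time.perf_counter()
        out = fn(*args)
        if block:
            out = jax.block_until_ready(out)
        best = min(best, time.perf_counter() - start)
    return best


@jax.jit
def matmul(w, x):
    return w @ x


n = 1000  # hypothetical size, kept small so the sketch runs quickly
w = jax.random.normal(jax.random.key(0), (n, n))
x = jax.random.normal(jax.random.key(1), (n, n))

matmul(w, x).block_until_ready()  # warm-up call so compilation is not timed

dispatch_only = time_call(matmul, w, x, block=False)
blocked = time_call(matmul, w, x, block=True)
print(f"dispatch only: {dispatch_only:.6f} s, blocked: {blocked:.6f} s")
```

The gap between the two numbers depends on the backend and problem size, so treat the printed values as indicative only.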

Equinox does this so that runtime errors are correctly surfaced during the JIT'd call, and not at some later point (or possibly not at all if the program stops before then).

I tested your benchmark with this addition, using both jax.jit and eqx.filter_jit, and get comparable timings.

patrick-kidger avatar Nov 17 '24 20:11 patrick-kidger