
Slow reductions on CPU


What happened?

The following JAX program, run with `JAX_PLATFORMS=iree`, benchmarks max reductions over every subset of the axes of an array against NumPy. My particular build of NumPy is single-threaded.

import jax.numpy as jnp
import numpy as np
from itertools import chain, combinations
import timeit

def powerset(iterable):
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))


x = np.random.randn(1500, 1500, 151).astype(np.float32)
y = jnp.asarray(x)


print("Numpy timings:")
for axes in powerset({0, 1, 2}):
    n = 15
    t = timeit.timeit(stmt='np.amax(x, axis=tuple(axes))', globals={'np': np, 'x':x, 'axes':axes}, number=n)
    print(f"axes={axes} time: {t/n*1000}ms")

print("IREE timings:")
for axes in powerset({0, 1, 2}):
    n = 15
    np.asarray(jnp.amax(y, axis=tuple(axes)))  # warmup: trigger compilation before timing
    # Time the device array y (bound as 'x' in the stmt), matching the warmup above.
    t = timeit.timeit(stmt='np.asarray(jnp.amax(x, axis=tuple(axes)))', globals={'np': np, 'jnp': jnp, 'x': y, 'axes': axes}, number=n)
    print(f"axes={axes} time: {t/n*1000}ms")

The output of the original benchmark on my 96-vCPU GCP VM is:

Numpy timings:
axes=() time: 246.87468299331763ms
axes=(0,) time: 170.56808766598502ms
axes=(1,) time: 212.64138147234917ms
axes=(2,) time: 268.60458532658714ms
axes=(0, 1) time: 214.8650777991861ms
axes=(0, 2) time: 225.28301280302307ms
axes=(1, 2) time: 127.62742745690048ms
axes=(0, 1, 2) time: 126.01195853203534ms
IREE timings:
axes=() time: 271.54760079768795ms
axes=(0,) time: 441.6122421932717ms
axes=(1,) time: 327.67158267088234ms
axes=(2,) time: 96.2033308732013ms
axes=(0, 1) time: 838.4074673211824ms
axes=(0, 2) time: 155.43748705337447ms
axes=(1, 2) time: 128.6482287881275ms
axes=(0, 1, 2) time: 1307.027867793416ms

i.e., the IREE local-task backend seems significantly slower than NumPy for some axis choices, notably `(0,)`, `(1,)`, `(0, 1)`, and `(0, 1, 2)`. The poor performance of the full reduction, `(0, 1, 2)`, is particularly surprising.
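
For reference, the MHLO below can be obtained with something like the following sketch (assuming the `jax.jit(...).lower(...).compiler_ir(...)` AOT API available in JAX around this time):

# Sketch: print the MHLO that JAX lowers the full reduction to.
import jax
lowered = jax.jit(lambda a: jnp.amax(a, axis=(0, 1, 2))).lower(y)
print(lowered.compiler_ir(dialect='mhlo'))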

The MHLO for the full reduction (`axes=(0, 1, 2)`) looks like this:

module @jit__lambda_.21 {
  func.func public @main(%arg0: tensor<1500x1500x151xf32>) -> tensor<f32> {
    %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
    %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0, 1, 2] : (tensor<1500x1500x151xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
      mhlo.return %2 : tensor<f32>
    }
    return %1 : tensor<f32>
  }
}
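
Not benchmarked here, but one experiment that might help isolate the problem is to stage the full reduction so that the first step is the fast `axes=(2,)` case from the timings above, leaving a large outer result to reduce afterwards. A hedged sketch:

# Sketch: staged full reduction instead of a single reduce over all dims.
partial = jnp.amax(y, axis=2)          # (1500, 1500, 151) -> (1500, 1500), the fast case above
full = jnp.amax(partial, axis=(0, 1))  # (1500, 1500) -> scalar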

Steps to reproduce your issue

See above.

What component(s) does this issue relate to?

Compiler

Version information

git commit e3cbeac692

Additional context

No response

hawkinsp · Sep 09 '22 15:09