iree
Slow reductions on CPU
What happened?
The following JAX program, run with JAX_PLATFORMS=iree, benchmarks max-reductions over every combination of axes of a 3-D array against NumPy. My particular build of NumPy is single-threaded.
import jax.numpy as jnp
import numpy as np
from itertools import chain, combinations
import timeit
def powerset(iterable):
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
x = np.random.randn(1500, 1500, 151).astype(np.float32)
y = jnp.asarray(x)
print("Numpy timings:")
for axes in powerset({0, 1, 2}):
    n = 15
    t = timeit.timeit(stmt='np.amax(x, axis=tuple(axes))', globals={'np': np, 'x':x, 'axes':axes}, number=n)
    print(f"axes={axes} time: {t/n*1000}ms")
print("IREE timings:")
for axes in powerset({0, 1, 2}):
    n = 15
    np.asarray(jnp.amax(y, axis=tuple(axes)))  # warmup
    t = timeit.timeit(stmt='np.asarray(jnp.amax(x, axis=tuple(axes)))', globals={'np': np, 'jnp': jnp, 'x':x, 'axes':axes}, number=n)
    print(f"axes={axes} time: {t/n*1000}ms")
The output on my 96-vCPU GCP VM is:
Numpy timings:
axes=() time: 246.87468299331763ms
axes=(0,) time: 170.56808766598502ms
axes=(1,) time: 212.64138147234917ms
axes=(2,) time: 268.60458532658714ms
axes=(0, 1) time: 214.8650777991861ms
axes=(0, 2) time: 225.28301280302307ms
axes=(1, 2) time: 127.62742745690048ms
axes=(0, 1, 2) time: 126.01195853203534ms
IREE timings:
axes=() time: 271.54760079768795ms
axes=(0,) time: 441.6122421932717ms
axes=(1,) time: 327.67158267088234ms
axes=(2,) time: 96.2033308732013ms
axes=(0, 1) time: 838.4074673211824ms
axes=(0, 2) time: 155.43748705337447ms
axes=(1, 2) time: 128.6482287881275ms
axes=(0, 1, 2) time: 1307.027867793416ms
i.e., IREE local-task seems significantly slower than NumPy for some axis choices, notably [0], [1], [0, 1], and [0, 1, 2]. The poor performance for the full reduction ([0, 1, 2]) is particularly surprising: about 1307 ms with IREE versus about 126 ms with NumPy.
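To make the comparison easier to read, here is a small sketch that computes the IREE/NumPy time ratio per axis choice. The numbers are the timings from the output above, rounded to two decimals; the dictionary names are just illustrative.

```python
# Timings in ms per call, copied (rounded) from the output above.
numpy_ms = {
    (): 246.87, (0,): 170.57, (1,): 212.64, (2,): 268.60,
    (0, 1): 214.87, (0, 2): 225.28, (1, 2): 127.63, (0, 1, 2): 126.01,
}
iree_ms = {
    (): 271.55, (0,): 441.61, (1,): 327.67, (2,): 96.20,
    (0, 1): 838.41, (0, 2): 155.44, (1, 2): 128.65, (0, 1, 2): 1307.03,
}

# A ratio above 1 means IREE was slower than NumPy for that axis choice.
slowdown = {axes: iree_ms[axes] / numpy_ms[axes] for axes in numpy_ms}

# Print the worst cases first.
for axes, ratio in sorted(slowdown.items(), key=lambda kv: kv[1], reverse=True):
    print(f"axes={axes}: IREE/NumPy = {ratio:.2f}x")
```

The full reduction comes out as the worst case by a wide margin, while reductions that include only the innermost axis (2) are the ones where IREE is ahead.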
The MHLO looks like this:
module @jit__lambda_.21 {
  func.func public @main(%arg0: tensor<1500x1500x151xf32>) -> tensor<f32> {
    %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
    %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0, 1, 2] : (tensor<1500x1500x151xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
      %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
      mhlo.return %2 : tensor<f32>
    }
    return %1 : tensor<f32>
  }
}
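For readers less familiar with MHLO, the module above is just a max-fold over all elements, seeded with the constant 0xFF800000, which is the float32 bit pattern for negative infinity. A rough pure-Python sketch of those semantics follows; `reduce_max_all` and `small` are illustrative names of mine, not anything from IREE or JAX, and this says nothing about how IREE actually schedules the reduction.

```python
import math
import struct
from functools import reduce
from itertools import chain

# 0xFF800000 is the IEEE 754 binary32 bit pattern for negative infinity,
# i.e. the identity element for a max-reduction.
init = struct.unpack(">f", bytes.fromhex("FF800000"))[0]

def reduce_max_all(tensor3d, init_value):
    """Fold max over every element of a nested 3-D list, starting from init_value."""
    flat = chain.from_iterable(chain.from_iterable(tensor3d))
    return reduce(max, flat, init_value)

# Tiny 2x2x2 stand-in for the 1500x1500x151 input.
small = [[[1.0, -2.0], [3.5, 0.0]], [[-7.0, 2.25], [0.5, 3.0]]]
result = reduce_max_all(small, init)
print(init, result)
```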
Steps to reproduce your issue
See above.
What component(s) does this issue relate to?
Compiler
Version information
git commit e3cbeac692
Additional context
No response