
BUG: jnp.cumsum(np.arange(2**14)) gives a segmentation fault.

Open AlexanderMath opened this issue 9 months ago • 2 comments

Description

Reproducer

import numpy as np 
import jax 
import jax.numpy as jnp 

# works
#arr = np.arange(2**13)
#print(np.cumsum(arr))
#print(jax.jit(jnp.cumsum, backend="ipu")(arr))

# gives a segmentation fault
arr = np.arange(2**14)
print(np.cumsum(arr))
print(jax.jit(jnp.cumsum, backend="ipu")(arr))

Output

$ python  reproduce.py 
[        0         1         3 ... 134176771 134193153 134209536]
Segmentation fault (core dumped)

Note:

  • initially suspected int32 overflow, but the largest output value is 2**13 · (2**14 − 1) = 134,209,536, well below 2**31 ≈ 2.1e9

Meta comment. The reproducer took 2 hours to make because jnp.cumsum was used inside ~400 lines of code, and I wrongly assumed jnp.cumsum was unlikely to cause a segmentation fault compared with the other suspects: tessellate-ipu, the C code, use of uint in the C code, Poplar's simulation of uint64 in the C code, passing data from Python to C, index computations, and so on. Would it be a lot of work to add automated testing of these basic (np, jnp) functions?
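A rough sketch of what such a test might look like, for illustration only (pytest, the list of ops, and the sizes are my assumptions; backend="ipu" is taken from the reproducer above):

import numpy as np
import jax
import jax.numpy as jnp
import pytest

# Hypothetical coverage test: compare each jnp op against its numpy reference
# for a range of input sizes on the IPU backend.
SIZES = [2**10, 2**13, 2**14, 2**15]
OPS = [(np.cumsum, jnp.cumsum), (np.cumprod, jnp.cumprod)]

@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("np_op,jnp_op", OPS)
def test_matches_numpy(np_op, jnp_op, size):
    x = np.arange(size, dtype=np.int32)
    expected = np_op(x)
    result = np.asarray(jax.jit(jnp_op, backend="ipu")(x))
    np.testing.assert_array_equal(expected, result)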

What jax/jaxlib version are you using?

0.3.16

Which accelerator(s) are you using?

IPU MK2

Additional System Info

No response

AlexanderMath avatar Sep 10 '23 13:09 AlexanderMath

Found the same issue with jnp.cumprod when trying to use the log trick as a temporary workaround.
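For context, the log trick I mean: cumsum(x) == log(cumprod(exp(x))), so cumprod could in principle stand in for the crashing cumsum. A minimal sketch of that idea (the helper name is mine); it doesn't help here because cumprod segfaults as well, and exp overflows quickly for large inputs anyway:

import jax.numpy as jnp

def cumsum_via_cumprod(x):
    # mathematically cumsum(x) == log(cumprod(exp(x))), up to floating-point error
    return jnp.log(jnp.cumprod(jnp.exp(x)))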

AlexanderMath avatar Sep 10 '23 13:09 AlexanderMath

Here's a hacky temporary workaround. It uses matrix multiplication to compute the cumulative sum within chunks of 2**7 elements and then adds the appropriate offset to each chunk. Use with caution. More than 90% of the time is spent adding the offsets.

import numpy as np
import jax
import jax.numpy as jnp

def matmul_cumsum_jax(arr):
    # cumulative sum of one chunk: lower-triangular matrix of ones (float32 by default) @ chunk
    return jnp.tril(jnp.ones((len(arr), len(arr)))) @ arr

def cumsum_jax(arr):
    chunk_size = 2**7
    original_shape = arr.shape
    # pad so the length is a multiple of chunk_size
    padding = chunk_size - (len(arr) % chunk_size) if len(arr) % chunk_size != 0 else 0
    arr = jnp.pad(arr, (0, padding))
    num_chunks = len(arr) // chunk_size
    chunks = arr.reshape(num_chunks, chunk_size)
    # cumulative sum within each chunk
    chunks = jax.vmap(matmul_cumsum_jax)(chunks)
    # starting offset of each chunk = total of all preceding chunks
    offset = 0
    offsets = [offset]
    for chunk in chunks:
        offset += chunk[-1]
        offsets.append(offset)
    # add each chunk's starting offset to its within-chunk cumulative sums
    chunks = jax.vmap(jax.lax.add, in_axes=(0, 0))(chunks, jnp.array(offsets[:-1]))
    # flatten and drop the padding
    return jnp.concatenate(chunks).reshape(-1)[:original_shape[0]]

arange = np.arange(2**14)
arange = np.concatenate((np.zeros(1), np.diff(arange))).astype(np.int32)
true_indxs = np.cumsum(arange)
us_indxs = np.asarray(jax.jit(cumsum_jax, backend="ipu")(arange)).astype(np.int32)
print(true_indxs[::127])
print(us_indxs[::127])
print(np.max(np.abs(true_indxs - us_indxs)))
print(np.all(true_indxs==us_indxs))

(screenshot of the printed output omitted)

AlexanderMath avatar Sep 10 '23 14:09 AlexanderMath