
Strange slowdown when computing some gradients with complex128

Open y1xiaoc opened this issue 3 years ago • 3 comments

When I tried to compute gradients of a complex128 matrix operation, I noticed that in some cases the compiled function is very slow on GPU. The slowdown only occurs for the gradient with respect to specific arguments, while for another, similar argument the computation time is normal. I rewrote the elementwise complex multiplication by hand and the problem went away. An example function:

import jax.numpy as jnp
from jax import lax

def test_fun_scan(scale, mats, vec):
    scale = scale[:, None, None] + 0j
    cmats = scale * mats
    vec, _ = lax.scan(lambda v, m: (m @ v, None), vec + 0j, cmats)
    return vec

scale = jnp.ones(11)
batch_mats = jnp.zeros((2000, 11, 10, 10))
vec = jnp.ones((10, 2))

The gradient with respect to scale is several times slower than the gradient with respect to mats when I vmap the function over mats. Replacing the scan with a Python loop makes the problem much worse, while rewriting the complex multiplication by hand resolves it.

For a full example and benchmark, please see the following colab: https://colab.research.google.com/drive/1g8zmhzT1obIOiRMSiPk7E3o34QjUJWOl?usp=sharing

I'm assuming there might be something wrong with the compiler. I also noticed that the jnp.exp function sometimes has performance problems during gradient computation, but I didn't extract an example for it.

y1xiaoc avatar Sep 23 '21 02:09 y1xiaoc

Thanks for the report! The behavior does not arise on CPU, so I suspect this is an XLA:GPU issue. GPU hardware in general does not have good support for 64-bit operations, so less attention has been paid to optimizing this type of code execution. And indeed, the issue goes away when disabling X64.

I'm not sure of the likelihood of this being fixed, but in general if you're sensitive to performance I'd suggest avoiding 64-bit computation on GPU (not because of lack of support in JAX or XLA, but because the hardware simply doesn't handle 64-bit computations efficiently).
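For reference, JAX defaults to 32-bit types, so complex arrays are complex64 unless x64 mode is explicitly enabled; this is the switch being discussed. A minimal sketch (the config flag should be set before any arrays are created):

```python
import jax
# Opt in to 64-bit types; without this, JAX silently uses 32-bit.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)          # float64 with x64 enabled (float32 otherwise)
c = x + 0j               # promotes to complex128 (complex64 otherwise)
```

Leaving x64 disabled is what makes the slowdown disappear, at the cost of precision.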

jakevdp avatar Sep 25 '21 14:09 jakevdp

I wonder if it would be possible to isolate the complex128 multiplication issue with a simpler XLA program? If we could do that, we would have more likelihood of a useful bug report to XLA:GPU.

jakevdp avatar Sep 25 '21 14:09 jakevdp

Thank you for the reply! Unfortunately I need the x64 precision in my application. For now I've managed to work around it by using the custom complex multiplication, as in the latter part of the example, which gave a ~8x speedup.
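The workaround itself lives in the linked colab rather than in this thread, but a hand-rolled elementwise complex multiply of the kind described would plausibly look like this: split both operands into real and imaginary parts and apply (a+bi)(c+di) = (ac−bd) + (ad+bc)i with real arithmetic only, so XLA never sees a native complex multiply. The function name and exact structure are my assumptions.

```python
import jax
import jax.numpy as jnp

def complex_mul(x, y):
    # Manual elementwise complex multiplication built from four real
    # multiplies, sidestepping XLA's native complex multiply.
    xr, xi = x.real, x.imag
    yr, yi = y.real, y.imag
    return jax.lax.complex(xr * yr - xi * yi, xr * yi + xi * yr)

a = jnp.array([1.0 + 2.0j])
b = jnp.array([3.0 - 1.0j])
out = complex_mul(a, b)   # matches a * b
```

In the reported use case, substituting this for `scale * mats` (with both operands broadcast as complex) was enough to recover normal gradient performance.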

I was not able to isolate the issue to a single complex128 multiplication; the subsequent scan or loop (or just a long sequence of operations?) seems to be necessary. Doing only an elementwise multiplication and computing its gradient works just fine.

y1xiaoc avatar Sep 25 '21 15:09 y1xiaoc

I think this problem is gone in recent jax versions. I will close the issue.

y1xiaoc avatar Nov 03 '23 19:11 y1xiaoc