Strange slowdown when computing some gradients with complex128
When I tried to do grad computation on a complex128 matrix, I noticed there are cases where the compiled function is very slow on GPU. The slow behavior only happens for the grad with respect to specific arguments, while for another similar argument the computation time is normal. I tried rewriting the elementwise complex multiplication on my own and the problem was solved. An example function is as follows:
import jax.numpy as jnp
from jax import lax

def test_fun_scan(scale, mats, vec):
    # promote to complex and scale each matrix, then apply them in sequence
    scale = scale[:, None, None] + 0j
    cmats = scale * mats
    vec, _ = lax.scan(lambda v, m: (m @ v, None), vec + 0j, cmats)
    return vec
scale = jnp.ones(11)
batch_mats = jnp.zeros((2000, 11, 10, 10))
vec = jnp.ones((10, 2))
and the grad wrt scale is several times slower than the grad wrt mats when I vmap the function over mats. Changing the scan into a Python loop makes the problem much more severe, while rewriting the complex multiplication resolves it.
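Concretely, the comparison looks roughly like this (a sketch with a made-up summed-real loss; the exact benchmark is in the colab below):

import jax

loss = lambda scale, mats, vec: jnp.real(test_fun_scan(scale, mats, vec)).sum()
# grad wrt scale (argnums=0) vs. grad wrt mats (argnums=1), vmapped over the batch axis of mats
grad_scale = jax.jit(jax.vmap(jax.grad(loss, argnums=0), in_axes=(None, 0, None)))
grad_mats = jax.jit(jax.vmap(jax.grad(loss, argnums=1), in_axes=(None, 0, None)))
# grad_scale(scale, batch_mats, vec) is several times slower than grad_mats(scale, batch_mats, vec)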
For a full example and benchmark, please see the following colab: https://colab.research.google.com/drive/1g8zmhzT1obIOiRMSiPk7E3o34QjUJWOl?usp=sharing
I'm assuming there might be something wrong with the compiler. I also noticed that the jnp.exp function sometimes has performance problems when computing gradients, but I didn't extract an example for it.
Thanks for the report! The behavior does not arise on CPU, so I suspect this is an XLA:GPU issue. GPU hardware in general does not have good support for 64-bit operations, so less attention has been paid to optimizing this type of code execution. And indeed, the issue goes away when disabling X64.
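For reference, x64 mode is controlled by a global flag; disabling it makes these arrays complex64 instead of complex128:

import jax

# x64 is off by default; your code presumably enables it somewhere
jax.config.update("jax_enable_x64", False)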
I'm not sure about the likelihood of this being fixed, but in general, if you're sensitive to performance, I'd suggest avoiding 64-bit computation on GPU (not because of a lack of support in JAX or XLA, but because the hardware simply doesn't handle 64-bit computation efficiently).
I wonder if it would be possible to isolate the complex128 multiplication issue in a simpler XLA program? If we could do that, we would have a better chance of filing a useful bug report against XLA:GPU.
Thank you for the reply! Unfortunately I need x64 precision in my application. For now I managed to work around the issue by using the custom complex multiplication shown in the latter part of the example, which gave me a ~8x speedup.
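The idea of the workaround, roughly (a sketch of the approach, not the exact code from the colab): carry the real and imaginary parts as separate real arrays and write the elementwise complex product out by hand, so only float64 multiplies and adds are emitted.

def cmul(a_re, a_im, b_re, b_im):
    # (a_re + i*a_im) * (b_re + i*b_im) expanded into real arithmetic
    return a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re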
I was not able to isolate the issue to a single complex128 multiplication; the following scan or loop (or perhaps just a long sequence of operations?) seems to be necessary. Doing only an elementwise multiplication and computing its gradient works just fine.
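For example, something as small as this (a sketch of the kind of minimal case I tried) computes its gradient at normal speed:

def only_mul(scale, mats):
    # just the elementwise complex multiplication, with no scan/loop afterwards
    return jnp.real((scale[:, None, None] + 0j) * mats).sum()

fast_grad = jax.jit(jax.grad(only_mul, argnums=0))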
I think this problem is gone in recent JAX versions. I will close the issue.