Batched JAX linear solves bugged for large batches
Hello,
I opened a similar issue on the main JAX repository (https://github.com/google/jax/issues/19431), but I thought it might get more attention here.
Batched JAX linear solves seem to be bugged for large batches on GPU, even when everything still fits comfortably in GPU memory. In short, if you try to solve a large batch of linear systems, the JAX LU/Cholesky solvers will sometimes return NaNs or other incorrect results without throwing an error or warning. The SVD-based solve seems to work better, though it also fails once you get close enough to filling the full GPU memory. The QR-based solve is, strangely, too slow for me to test at large batch sizes. The lineax solves show the same behavior, although lineax does throw an error upon seeing NaNs.
Below is a test and its output: solving Ax = b, where A is the identity and b is all ones, returns NaNs. I am curious whether someone can reproduce this behavior and has any ideas on what to do.
Thank you for making this nice library! Best, Marc
import jax
import jax.numpy as jnp
import lineax as lx

device = jax.local_devices()[0]
print('on device:', device)

m = 10

batched_solve_lu = jax.vmap(
    lambda matrix, vector: lx.linear_solve(
        lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()
    ).value
)
batched_solve_SVD = jax.vmap(
    lambda matrix, vector: lx.linear_solve(
        lx.MatrixLinearOperator(matrix), vector, solver=lx.SVD()
    ).value
)

solve_fns = [jax.scipy.linalg.solve, batched_solve_SVD, batched_solve_lu]

for solve_fn in solve_fns:
    for n in [int(1e6), int(1e7)]:
        A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
        x = jnp.ones([n, m])
        b = jax.lax.batch_matmul(A, x[..., None])[..., 0]
        x_solved = solve_fn(A, b)
        print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
        print("Memory info ", device.memory_stats())
Output:
on device: cuda:0
Average error with n =1000000, 0.0
Memory info {'bytes_in_use': 520001024, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 804000512, 'largest_free_block_bytes': 0, 'num_allocs': 29, 'peak_bytes_in_use': 1324001536, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, nan
Memory info {'bytes_in_use': 5200002048, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 58, 'peak_bytes_in_use': 13280002560, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0
Memory info {'bytes_in_use': 520000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 98, 'peak_bytes_in_use': 13280002560, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, 0.0
Memory info {'bytes_in_use': 5200000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 136, 'peak_bytes_in_use': 18120001024, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0
Memory info {'bytes_in_use': 560002560, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 174, 'peak_bytes_in_use': 18120001024, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
2024-01-24 15:17:43.130711: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.
If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
`jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
that is does not actually satisfy.
At:
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_errors.py(70): raises
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(258): _flat_callback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(52): pure_callback_impl
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(188): _callback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(2267): _wrapped_callback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1152): __call__
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1151): _pjit_call_impl_python
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1195): call_impl_cache_miss
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1211): _pjit_call_impl
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(869): process_primitive
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1427): _pjit_batcher
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/batching.py(433): process_primitive
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(255): cache_miss
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(200): _call
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_module.py(935): __call__
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(206): __call__
/tmp/ipykernel_3270959/20536111.py(9): <lambda>
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/linear_util.py(190): call_wrapped
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/api.py(1260): vmap_f
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
/tmp/ipykernel_3270959/20536111.py(22): <module>
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3526): run_code
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3466): run_ast_nodes
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3284): run_cell_async
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3079): _run_cell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3024): run_cell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/zmqshell.py(546): run_cell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/ipkernel.py(422): do_execute
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(740): execute_request
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(505): process_one
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
/home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/events.py(80): _run
/home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(1905): _run_once
/home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(601): run_forever
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/tornado/platform/asyncio.py(195): start
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelapp.py(736): start
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/traitlets/config/application.py(1053): launch_instance
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
/home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(87): _run_code
/home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(197): _run_module_as_main
2024-01-24 15:17:43.130798: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2711] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.
If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
`jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
that is does not actually satisfy.
At:
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_errors.py(70): raises
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(258): _flat_callback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(52): pure_callback_impl
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(188): _callback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(2267): _wrapped_callback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1152): __call__
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1151): _pjit_call_impl_python
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1195): call_impl_cache_miss
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1211): _pjit_call_impl
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(869): process_primitive
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1427): _pjit_batcher
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/batching.py(433): process_primitive
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(255): cache_miss
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(200): _call
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_module.py(935): __call__
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(206): __call__
/tmp/ipykernel_3270959/20536111.py(9): <lambda>
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/linear_util.py(190): call_wrapped
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/api.py(1260): vmap_f
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
/tmp/ipykernel_3270959/20536111.py(22): <module>
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3526): run_code
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3466): run_ast_nodes
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3284): run_cell_async
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3079): _run_cell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3024): run_cell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/zmqshell.py(546): run_cell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/ipkernel.py(422): do_execute
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(740): execute_request
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(505): process_one
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
/home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/events.py(80): _run
/home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(1905): _run_once
/home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(601): run_forever
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/tornado/platform/asyncio.py(195): start
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelapp.py(736): start
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/traitlets/config/application.py(1053): launch_instance
/home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
/home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(87): _run_code
/home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(197): _run_module_as_main
; current tracing scope: custom-call.101; current profiling annotation: XlaModule:#hlo_module=jit_linear_solve,program_id=40#.
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[1], line 22
19 x = jnp.ones([n,m])
20 b = jax.lax.batch_matmul(A,x[...,None])[...,0]
---> 22 x_solved = solve_fn(A,b)
23 print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
24 print("Memory info ", device.memory_stats())
[... skipping hidden 3 frame]
Cell In[1], line 9, in <lambda>(matrix, vector)
5 print('on device:', device)
7 m = 10
----> 9 batched_solve_lu = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()).value)
10 batched_solve_SVD = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.SVD()).value)
12 solve_fns = [jax.scipy.linalg.solve, batched_solve_SVD, batched_solve_lu]
[... skipping hidden 14 frame]
File ~/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1152, in ExecuteReplicated.__call__(self, *args)
1150 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1151 else:
-> 1152 results = self.xla_executable.execute_sharded(input_bufs)
1153 if dispatch.needs_check_special():
1154 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.
If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
`jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also https://docs.kidger.site/equinox/api/errors/#equinox.error_if for more
information.
Hmm. So JAX and Lineax both basically do the same thing for the LU/QR/Cholesky solvers, which is to use the JAX (and thus probably CUDA) implementation of those decompositions.
The fact that the QR solve is slow is expected I think -- IIRC there's no CUDA implementation of a batched QR decomposition, so vmap is handled by computing the decomposition for each batch element sequentially.
I suspect the issue is probably somewhere in the underlying CUDA (cuSolver?) implementations. I think resolving this will probably need someone to go digging through things at that level, I'm afraid.
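For what it's worth, one quick way to confirm that the vmapped lineax LU solve and the plain JAX solve bottom out in the same decomposition is to inspect their jaxprs. This is just a sketch on a tiny batch (I haven't run it against your exact setup):

import jax
import jax.numpy as jnp
import lineax as lx

# Tiny batch of identity systems, just to see what the two code paths lower to.
A = jnp.repeat(jnp.identity(4)[None], 8, axis=0)
b = jnp.ones((8, 4))

lineax_lu_solve = jax.vmap(
    lambda matrix, vector: lx.linear_solve(
        lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()
    ).value
)

# Both printed jaxprs should contain the same `lu` decomposition primitive,
# i.e. both routes end up calling the same underlying (CUDA) kernel on GPU.
print(jax.make_jaxpr(lineax_lu_solve)(A, b))
print(jax.make_jaxpr(jnp.linalg.solve)(A, b))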
Hi Patrick,
Thanks for your answer!
I can't say I really understand how JAX/torch/cupy interact with CUDA code, but what is surprising to me is that this seems to be a bug only in JAX. Both torch/cupy seem to work, even though I would assume they use the same backend.
E.g.:
import numpy as np
import torch
n = int(1e7); m = 10
A = torch.tensor(np.repeat(np.identity(m)[None], n, axis=0))
L = torch.linalg.cholesky(A)
print(torch.linalg.norm(A - L))
Outputs:
tensor(0., dtype=torch.float64)
And the same thing for cupy, but JAX returns NaNs.
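A rough JAX equivalent of that check (a sketch of the kind of thing I'm running, not copied verbatim from my notebook) would be:

import jax.numpy as jnp

n, m = int(1e7), 10
A = jnp.repeat(jnp.identity(m)[None], n, axis=0)  # large batch of identity matrices
L = jnp.linalg.cholesky(A)
print(jnp.linalg.norm(A - L))  # should be 0.0; at this batch size it comes back NaN for me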
Oh interesting! Hmm, in that case I'm less certain of the reason. Maybe check that it's not a version issue? PyTorch and JAX tend to use different versions of the underlying NVIDIA libraries.
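For example, something like this prints the versions in play (just a sketch; the exact attribute names can vary a little between releases):

import jax, jaxlib, torch

print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
print("torch:", torch.__version__, "built against CUDA:", torch.version.cuda)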
Thanks for the suggestion! I tried a few different CUDA versions without any change, but updating jax seems to fix the problem, or at least it passes the few tests I have tried.
Curious! Well, I'm glad it's fixed. :) Possibly an issue with a particular version of jaxlib then, if updating the version fixed things.