spsolve exits with an error when solving with a sum of sparse matrices
Description
I am trying to solve a linear system (A1 + A2) x = b with sparse matrices A1 and A2 using jax.experimental.sparse.linalg.spsolve. spsolve takes CSR buffers (data, indices, indptr), which I get from a BCSR matrix, but BCSR does not yet support addition. So I store A1 and A2 in BCOO format, add them, and then convert the sum to BCSR format. However, this causes an error in spsolve.
I can reproduce the problem with the following simplified code:
Runs Fine
import jax
import jax.experimental.sparse
I_boo = jax.experimental.sparse.eye(2)
I_csr = jax.experimental.sparse.BCSR.from_bcoo(I_boo)
b = jax.numpy.ones(2)
jax.experimental.sparse.linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)
Exits in Error
import jax
import jax.experimental.sparse
I_boo = jax.experimental.sparse.eye(2)
I_boo = I_boo + I_boo  # the addition itself runs fine, but it makes spsolve fail
I_csr = jax.experimental.sparse.BCSR.from_bcoo(I_boo)
b = jax.numpy.ones(2)
jax.experimental.sparse.linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)
The error is copied below.
XlaRuntimeError Traceback (most recent call last)
<ipython-input-2-9aae69a4e181> in <module>
7
8 b = jax.numpy.ones(2)
----> 9 jax.experimental.sparse.linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)
~/.local/lib/python3.9/site-packages/jax/experimental/sparse/linalg.py in spsolve(data, indices, indptr, b, tol, reorder)
619 the sparse linear system.
620 """
--> 621 return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)
~/.local/lib/python3.9/site-packages/jax/_src/core.py in bind(self, *args, **params)
420 assert (not config.enable_checks.value or
421 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 422 return self.bind_with_trace(find_top_trace(args), args, params)
423
424 def bind_with_trace(self, trace, args, params):
~/.local/lib/python3.9/site-packages/jax/_src/core.py in bind_with_trace(self, trace, args, params)
423
424 def bind_with_trace(self, trace, args, params):
--> 425 out = trace.process_primitive(self, map(trace.full_raise, args), params)
426 return map(full_lower, out) if self.multiple_results else full_lower(out)
427
~/.local/lib/python3.9/site-packages/jax/_src/core.py in process_primitive(self, primitive, tracers, params)
911
912 def process_primitive(self, primitive, tracers, params):
--> 913 return primitive.impl(*tracers, **params)
914
915 def process_call(self, primitive, f, tracers, params):
~/.local/lib/python3.9/site-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
86 try:
---> 87 outs = fun(*args)
88 finally:
89 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 10 frame]
~/.local/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py in __call__(self, *args)
1203 or self.has_host_callbacks):
1204 input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1205 results = self.xla_executable.execute_sharded(
1206 input_bufs, with_tokens=True
1207 )
XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: Traceback (most recent call last):
File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
File "/usr/lib/python3.9/site-packages/ipykernel_launcher.py", line 16, in <module>
File "/usr/lib/python3.9/site-packages/traitlets/config/application.py", line 845, in launch_instance
File "/usr/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 612, in start
File "/usr/lib64/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in start
File "/usr/lib64/python3.9/asyncio/base_events.py", line 596, in run_forever
File "/usr/lib64/python3.9/asyncio/base_events.py", line 1890, in _run_once
File "/usr/lib64/python3.9/asyncio/events.py", line 80, in _run
File "/usr/lib64/python3.9/site-packages/tornado/ioloop.py", line 688, in <lambda>
File "/usr/lib64/python3.9/site-packages/tornado/ioloop.py", line 741, in _run_callback
File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 814, in inner
File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 775, in run
File "/usr/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 362, in process_one
File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 234, in wrapper
File "/usr/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 265, in dispatch_shell
File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 234, in wrapper
File "/usr/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 540, in execute_request
File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 234, in wrapper
File "/usr/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 302, in do_execute
File "/usr/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 539, in run_cell
File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2886, in run_cell
File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2932, in _run_cell
File "/usr/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3155, in run_cell_async
File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3347, in run_ast_nodes
File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
File "<ipython-input-2-9aae69a4e181>", line 9, in <module>
File "/home/user/.local/lib/python3.9/site-packages/jax/experimental/sparse/linalg.py", line 621, in spsolve
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 422, in bind
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 425, in bind_with_trace
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 913, in process_primitive
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 298, in cache_miss
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 176, in _python_pjit_helper
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 2788, in bind
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 425, in bind_with_trace
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 913, in process_primitive
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1488, in _pjit_call_impl
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1471, in call_impl_cache_miss
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1427, in _pjit_call_impl_python
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 335, in wrapper
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 1205, in __call__
File "/home/user/.local/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 2466, in _wrapped_callback
File "/home/user/.local/lib/python3.9/site-packages/jax/experimental/sparse/linalg.py", line 547, in _callback
File "/home/user/.local/lib/python3.9/site-packages/scipy/sparse/linalg/_dsolve/linsolve.py", line 242, in spsolve
File "/home/user/.local/lib/python3.9/site-packages/scipy/sparse/_compressed.py", line 1125, in sum_duplicates
ValueError: WRITEBACKIFCOPY base is read-only
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.9.7 (default, Aug 30 2021, 00:00:00) [GCC 11.2.1 20210728 (Red Hat 11.2.1-1)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='959df832f7d6', release='5.10.104-linuxkit', version='#1 SMP Thu Mar 17 17:08:06 UTC 2022', machine='x86_64')
Thanks for the report! It looks like spsolve fails when the CSR buffers contain duplicate indices, which is what the BCOO addition produces here (the traceback ends in SciPy's sum_duplicates). You can fix the issue in your case by deduplicating before passing the I_csr buffers to spsolve:
I_csr = I_csr.sum_duplicates()
I suspect this is working as intended, since spsolve is a lower-level function, though we should do a better job of documenting its requirements.
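A minimal sketch of the same fix applied to the general A1 + A2 case from the description. It assumes a JAX version that provides `BCOO.sum_duplicates` (jax 0.4.26, used in this report, does). Unlike the call above, it deduplicates at the BCOO level before converting to BCSR, rather than calling `sum_duplicates` on the BCSR matrix; both approaches should remove the repeated (row, col) entries that the addition leaves behind. The operands `A1` and `A2` are illustrative, not taken from the report.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse.linalg import spsolve

# Illustrative operands with identical sparsity patterns (the A1, A2 of the report).
A1 = sparse.eye(2)
A2 = sparse.BCOO.fromdense(2.0 * jnp.eye(2))

# Adding two BCOO matrices keeps the stored entries of both operands,
# so each diagonal position ends up stored more than once in A_sum.
A_sum = A1 + A2

# Merge repeated (row, col) entries by summing their values.
# Without an explicit nse, this needs concrete (non-traced) inputs, i.e. it runs outside jit.
A_dedup = A_sum.sum_duplicates()

# Convert the deduplicated matrix to CSR buffers and solve (A1 + A2) x = b.
A_csr = sparse.BCSR.from_bcoo(A_dedup)
b = jnp.ones(2)
x = spsolve(A_csr.data, A_csr.indices, A_csr.indptr, b)

print(A_sum.nse, A_dedup.nse)
print(x)  # A1 + A2 = 3 * I here, so x should be close to 1/3 in each entry
```

Inside a jitted function, passing an explicit `nse` to `sum_duplicates` (known from the sparsity pattern) avoids the data-dependent output size.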
Thank you, adding sum_duplicates resolved the issue! It would be great if JAX eventually handled this automatically, so that sparse matrices behave the same way as in SciPy.
I don't think we'll ever automatically take care of this, because deduplication is a relatively expensive operation (even detecting whether deduplication is necessary is expensive!), and people calling a low-level routine like spsolve generally care about performance.
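To illustrate why even the check is costly, here is a hypothetical helper, `has_duplicate_entries`, which is not part of JAX or SciPy and is not how either library implements this. It detects repeated (row, col) pairs in raw CSR buffers. In this sketch that requires sorting every stored entry, an O(nse log nse) pass over the whole matrix, which a routine like spsolve would have to pay on every call just to decide whether deduplication is needed.

```python
import jax.numpy as jnp


def has_duplicate_entries(indices, indptr):
    """Hypothetical helper: True if any (row, col) pair is stored more than once.

    `indices` and `indptr` are raw CSR buffers, as passed to spsolve.
    Eager-only sketch: the final bool() call does not work under jit.
    """
    nse = indices.shape[0]
    n_rows = indptr.shape[0] - 1
    # Expand indptr into one row id per stored entry (O(nse)).
    rows = jnp.repeat(jnp.arange(n_rows), jnp.diff(indptr), total_repeat_length=nse)
    # Sort entries by (row, col) so that equal pairs become adjacent (O(nse log nse)).
    order = jnp.lexsort((indices, rows))
    r, c = rows[order], indices[order]
    # A duplicate exists iff some entry matches its sorted neighbour.
    return bool(jnp.any((r[1:] == r[:-1]) & (c[1:] == c[:-1])))


# CSR buffers for a 2x2 diagonal matrix with each diagonal entry stored twice,
# similar to the summed matrix in the report, versus a deduplicated version.
dup_indices, dup_indptr = jnp.array([0, 0, 1, 1]), jnp.array([0, 2, 4])
clean_indices, clean_indptr = jnp.array([0, 1]), jnp.array([0, 1, 2])

print(has_duplicate_entries(dup_indices, dup_indptr))      # True
print(has_duplicate_entries(clean_indices, clean_indptr))  # False
```

Within a valid CSR matrix the rows are already grouped by indptr, so a real implementation could sort only within each row, but the per-call cost still grows with the number of stored entries.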