
spsolve exits with error when inverting matrix sum

Open buvoli opened this issue 1 year ago • 4 comments

Description

I am trying to solve a linear system (A1 + A2) x = b using sparse matrices A1, A2 and jax.experimental.sparse.linalg.spsolve. Since spsolve requires BCSR format, which does not yet support addition, I am storing A1 and A2 in BCOO format and then converting the sum to BCSR format. However, this causes an error in spsolve.
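The pattern described above can be sketched with two small matrices whose sparsity patterns overlap on the diagonal. The matrices are made up for illustration; the sketch compares the number of stored entries (`nse`) of the BCOO sum before and after `BCOO.sum_duplicates`, which merges entries that share a (row, col) position.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative matrices (not from the issue); both store entries at (0, 0) and (1, 1).
A1 = sparse.BCOO.fromdense(jnp.array([[4.0, 0.0], [0.0, 3.0]]))
A2 = sparse.BCOO.fromdense(jnp.array([[1.0, 2.0], [0.0, 1.0]]))

A = A1 + A2                  # BCOO addition; stored entries may repeat positions
A_dedup = A.sum_duplicates() # sorts indices and sums entries at the same position

print("stored entries after addition:      ", A.nse)
print("stored entries after sum_duplicates:", A_dedup.nse)
print(A_dedup.todense())
```

Both versions represent the same dense matrix; only the deduplicated one has a single stored entry per position.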

I can reproduce the problem with the following simplified code:

Runs Fine

import jax
import jax.experimental.sparse

I_boo = jax.experimental.sparse.eye(2)
I_csr = jax.experimental.sparse.BCSR.from_bcoo(I_boo)

b = jax.numpy.ones(2)
jax.experimental.sparse.linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)

Exits in Error

import jax
import jax.experimental.sparse

I_boo = jax.experimental.sparse.eye(2)
I_boo = I_boo + I_boo # runs fine, but causes spsolve to fail
I_csr = jax.experimental.sparse.BCSR.from_bcoo(I_boo)

b = jax.numpy.ones(2)
jax.experimental.sparse.linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)

The error is copied below.

XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-2-9aae69a4e181> in <module>
      7 
      8 b = jax.numpy.ones(2)
----> 9 jax.experimental.sparse.linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)

~/.local/lib/python3.9/site-packages/jax/experimental/sparse/linalg.py in spsolve(data, indices, indptr, b, tol, reorder)
    619     the sparse linear system.
    620   """
--> 621   return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)

~/.local/lib/python3.9/site-packages/jax/_src/core.py in bind(self, *args, **params)
    420     assert (not config.enable_checks.value or
    421             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 422     return self.bind_with_trace(find_top_trace(args), args, params)
    423 
    424   def bind_with_trace(self, trace, args, params):

~/.local/lib/python3.9/site-packages/jax/_src/core.py in bind_with_trace(self, trace, args, params)
    423 
    424   def bind_with_trace(self, trace, args, params):
--> 425     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    426     return map(full_lower, out) if self.multiple_results else full_lower(out)
    427 

~/.local/lib/python3.9/site-packages/jax/_src/core.py in process_primitive(self, primitive, tracers, params)
    911 
    912   def process_primitive(self, primitive, tracers, params):
--> 913     return primitive.impl(*tracers, **params)
    914 
    915   def process_call(self, primitive, f, tracers, params):

~/.local/lib/python3.9/site-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
     85     prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86     try:
---> 87       outs = fun(*args)
     88     finally:
     89       lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 10 frame]

~/.local/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py in __call__(self, *args)
   1203         or self.has_host_callbacks):
   1204       input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1205       results = self.xla_executable.execute_sharded(
   1206           input_bufs, with_tokens=True
   1207       )

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: Traceback (most recent call last):
  File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
  File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
  File "/usr/lib/python3.9/site-packages/ipykernel_launcher.py", line 16, in <module>
  File "/usr/lib/python3.9/site-packages/traitlets/config/application.py", line 845, in launch_instance
  File "/usr/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 612, in start
  File "/usr/lib64/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in start
  File "/usr/lib64/python3.9/asyncio/base_events.py", line 596, in run_forever
  File "/usr/lib64/python3.9/asyncio/base_events.py", line 1890, in _run_once
  File "/usr/lib64/python3.9/asyncio/events.py", line 80, in _run
  File "/usr/lib64/python3.9/site-packages/tornado/ioloop.py", line 688, in <lambda>
  File "/usr/lib64/python3.9/site-packages/tornado/ioloop.py", line 741, in _run_callback
  File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 814, in inner
  File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 775, in run
  File "/usr/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 362, in process_one
  File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 265, in dispatch_shell
  File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 540, in execute_request
  File "/usr/lib64/python3.9/site-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 302, in do_execute
  File "/usr/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 539, in run_cell
  File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2886, in run_cell
  File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2932, in _run_cell
  File "/usr/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
  File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3155, in run_cell_async
  File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3347, in run_ast_nodes
  File "/usr/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
  File "<ipython-input-2-9aae69a4e181>", line 9, in <module>
  File "/home/user/.local/lib/python3.9/site-packages/jax/experimental/sparse/linalg.py", line 621, in spsolve
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 422, in bind
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 425, in bind_with_trace
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 913, in process_primitive
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 298, in cache_miss
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 176, in _python_pjit_helper
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 2788, in bind
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 425, in bind_with_trace
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/core.py", line 913, in process_primitive
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1488, in _pjit_call_impl
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1471, in call_impl_cache_miss
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/pjit.py", line 1427, in _pjit_call_impl_python
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 335, in wrapper
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 1205, in __call__
  File "/home/user/.local/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 2466, in _wrapped_callback
  File "/home/user/.local/lib/python3.9/site-packages/jax/experimental/sparse/linalg.py", line 547, in _callback
  File "/home/user/.local/lib/python3.9/site-packages/scipy/sparse/linalg/_dsolve/linsolve.py", line 242, in spsolve
  File "/home/user/.local/lib/python3.9/site-packages/scipy/sparse/_compressed.py", line 1125, in sum_duplicates
ValueError: WRITEBACKIFCOPY base is read-only
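The last frames of the traceback point into SciPy: the CPU callback in `jax/experimental/sparse/linalg.py` passes the buffers to `scipy.sparse.linalg.spsolve`, which calls `sum_duplicates` to canonicalize the matrix in place. A plausible reading of the final `ValueError` (an interpretation, not confirmed in this thread) is that the NumPy arrays backing JAX buffers are read-only, so that in-place rewrite fails. The sketch below only checks the read-only part of that explanation:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones(3)
x_np = np.asarray(x)  # NumPy view of the JAX buffer (no copy is made)

# JAX marks NumPy views of its buffers as read-only, so in-place writes are rejected.
print(x_np.flags.writeable)
```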

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.9.7 (default, Aug 30 2021, 00:00:00)  [GCC 11.2.1 20210728 (Red Hat 11.2.1-1)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='959df832f7d6', release='5.10.104-linuxkit', version='#1 SMP Thu Mar 17 17:08:06 UTC 2022', machine='x86_64')

buvoli, May 01 '24 20:05

Thanks for the report! It looks like something in spsolve fails in the presence of duplicate indices. You can fix the issue in your case by doing this before passing the I_csr buffers to spsolve:

I_csr = I_csr.sum_duplicates()
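Applied to the failing reproducer, the suggested fix looks like the sketch below. It assumes SciPy is installed, since on CPU `spsolve` dispatches to `scipy.sparse.linalg.spsolve`, as the traceback shows.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse import linalg as sparse_linalg

I_boo = sparse.eye(2)
I_boo = I_boo + I_boo               # may store the same (row, col) more than once
I_csr = sparse.BCSR.from_bcoo(I_boo)
I_csr = I_csr.sum_duplicates()      # merge duplicates before calling spsolve

b = jnp.ones(2)
x = sparse_linalg.spsolve(I_csr.data, I_csr.indices, I_csr.indptr, b)
print(x)  # (I + I) x = b with b = ones, so x should be close to [0.5, 0.5]
```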

jakevdp, May 03 '24 18:05

I suspect this is working as intended, since spsolve is a lower-level function, though we should do a better job of documenting its requirements.
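One way to respect that requirement in user code is a small wrapper that deduplicates before calling the low-level routine. `spsolve_bcoo` below is a hypothetical helper, not part of JAX; it is a sketch assuming an unbatched 2D BCOO input and SciPy available for the CPU path.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse import linalg as sparse_linalg


def spsolve_bcoo(A: sparse.BCOO, b: jnp.ndarray, tol: float = 1e-6) -> jnp.ndarray:
    """Hypothetical convenience wrapper (not a JAX API).

    Merges duplicate entries, converts to BCSR, and calls the low-level
    spsolve, which expects CSR buffers without duplicate (row, col) entries.
    """
    # With nse=None, sum_duplicates inspects the indices, so this is not jit-compatible.
    A = A.sum_duplicates()
    A_csr = sparse.BCSR.from_bcoo(A)
    return sparse_linalg.spsolve(A_csr.data, A_csr.indices, A_csr.indptr, b, tol=tol)


# Usage with illustrative matrices: A1 + A2 = [[5, 1], [0, 4]].
A1 = sparse.BCOO.fromdense(jnp.array([[4.0, 1.0], [0.0, 3.0]]))
A2 = sparse.eye(2)
b = jnp.array([1.0, 2.0])
x = spsolve_bcoo(A1 + A2, b)
print(x)
```

The wrapper pays the deduplication cost on every call, which is exactly the trade-off the maintainers raise later in the thread.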

jakevdp, May 03 '24 18:05

Thank you, adding sum_duplicates resolved the issue! It would be great if JAX eventually took care of this automatically, so that its sparse matrices behave the same way as SciPy's.

buvoli, May 04 '24 20:05

I don't think we'll ever automatically take care of this, because deduplication is a relatively expensive operation (even detecting whether deduplication is necessary is expensive!), and people calling a low-level routine like spsolve generally care about performance.
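To see why even detection has a cost, consider a straightforward check for duplicates in a BCOO matrix: the indices have to be sorted (O(nse log nse)) before adjacent entries can be compared. `has_duplicate_indices` is a hypothetical helper written for illustration; it assumes an unbatched 2D BCOO whose `indices` array has shape `(nse, 2)`.

```python
import jax.numpy as jnp
from jax.experimental import sparse


def has_duplicate_indices(A: sparse.BCOO) -> bool:
    """Hypothetical helper: True if any (row, col) pair is stored more than once."""
    rows, cols = A.indices[:, 0], A.indices[:, 1]
    order = jnp.lexsort((cols, rows))  # last key is primary: sort by row, then column
    r, c = rows[order], cols[order]
    # After sorting, any duplicate pair must sit in adjacent positions.
    same = (r[1:] == r[:-1]) & (c[1:] == c[:-1])
    return bool(jnp.any(same))


I = sparse.eye(2)
print(has_duplicate_indices(I))
print(has_duplicate_indices(I + I))
print(has_duplicate_indices((I + I).sum_duplicates()))
```

The sort dominates the cost, so running this check inside every `spsolve` call would add noticeable overhead for large matrices.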

jakevdp, May 05 '24 14:05