
[REQ] High-Level Sparse Matrix Value Update API

Open • akrivx opened this issue 4 months ago • 1 comment

Description

Warp currently supports constructing a sparse matrix using warp.sparse.bsr_from_triplets. However, updating the numerical values in a matrix while keeping the sparsity pattern fixed is not straightforward.

This feature request proposes adding a high-level API to update the values of a sparse matrix in-place, without exposing or requiring the user to manually reconstruct the internal (i, j) mapping for each entry.

Context

In many applications, such as FEM, physics simulation, or iterative solvers, the sparsity pattern remains constant across time steps or iterations, but the values change.

Currently, the only way (as far as I understand it) to update values without rebuilding the matrix is to:

  1. Call A.uncompress_rows() to recover the row index of each stored non-zero.
  2. Launch a custom @wp.kernel that takes rows, A.columns, and A.values.
  3. Inside that kernel, recompute each value from its (row, column) pair.

This approach is low-level, can be fragile, and requires the user to manage index bookkeeping manually.
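As a rough mental model of step 1, the sketch below expands compressed row offsets into one row index per stored non-zero, in plain Python. It assumes a CSR-style layout with scalar blocks, where positions offsets[r] through offsets[r + 1] - 1 hold the non-zeros of row r. uncompress_rows_reference is a hypothetical helper for illustration only, not part of Warp.

```python
def uncompress_rows_reference(offsets):
    """Return the row index of every stored non-zero, given CSR row offsets.

    offsets has length nrow + 1; the non-zeros of row r occupy
    positions offsets[r] .. offsets[r + 1] - 1.
    """
    rows = []
    for r in range(len(offsets) - 1):
        count = offsets[r + 1] - offsets[r]  # number of non-zeros in row r
        rows.extend([r] * count)
    return rows


# 3x3 pattern with non-zeros at (0,0), (0,2), (2,1), (2,2); row 1 is empty
offsets = [0, 2, 2, 4]
columns = [0, 2, 1, 2]
rows = uncompress_rows_reference(offsets)
print(list(zip(rows, columns)))  # [(0, 0), (0, 2), (2, 1), (2, 2)]
```

This (rows, columns) pairing is exactly the bookkeeping the user has to carry into the update kernel.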

Example: Current Pattern

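import warp as wp


@wp.func
def compute_new_value(i: int, j: int) -> float:  # user-defined update rule
    return float(i + j)


@wp.kernel
def update_values(
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    vals: wp.array(dtype=float),
):
    tid = wp.tid()  # index of the stored non-zero
    i = rows[tid]
    j = cols[tid]
    vals[tid] = compute_new_value(i, j)


# A: an existing BsrMatrix with scalar float32 values
wp.launch(
    kernel=update_values,
    dim=A.nnz,
    inputs=[A.uncompress_rows(), A.columns, A.values],
)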

Drawbacks:

  • Requires knowledge of CSR internals (uncompress_rows())
  • Manual (i, j) mapping may be error-prone
  • Doesn't communicate intent clearly
  • Hard to reuse

Proposed API

Introduce a utility like:

warp.sparse.update_values(A, update_func)

Where:

  • update_func(i, j, old_value) returns the updated value for position (i, j)
  • old_value could be omitted when the update does not need it
  • Internally, Warp would handle row/col indexing and launch an efficient kernel

This makes the common update pattern explicit, safe, and concise.

Optional Enhancements

  • An in_place=True/False flag on update_values to choose between mutating the passed matrix and returning a new one
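The semantics proposed above can be pinned down with a small pure-Python reference model that operates on CSR-style offsets, columns and values lists. update_values_reference is a hypothetical name, not a Warp API. The sketch assumes that old_value is made optional by checking how many parameters update_func accepts, and it models the in_place flag by either mutating values or working on a copy.

```python
import inspect


def update_values_reference(offsets, columns, values, update_func, in_place=True):
    """Hypothetical reference model of the proposed update_values semantics.

    update_func may take (i, j) or (i, j, old_value). The sparsity pattern
    (offsets, columns) is never modified; only the values change.
    """
    takes_old = len(inspect.signature(update_func).parameters) >= 3
    out = values if in_place else list(values)  # copy when not in place
    for i in range(len(offsets) - 1):
        for k in range(offsets[i], offsets[i + 1]):
            j = columns[k]
            out[k] = update_func(i, j, out[k]) if takes_old else update_func(i, j)
    return out


offsets = [0, 2, 2, 4]
columns = [0, 2, 1, 2]
values = [1.0, 2.0, 3.0, 4.0]

# Uses old_value and returns a new list; values is left untouched
scaled = update_values_reference(offsets, columns, values, lambda i, j, v: 2.0 * v, in_place=False)
print(scaled)  # [2.0, 4.0, 6.0, 8.0]
print(values)  # [1.0, 2.0, 3.0, 4.0]

# Ignores old_value and overwrites values in place
update_values_reference(offsets, columns, values, lambda i, j: float(i + j))
print(values)  # [0.0, 2.0, 3.0, 4.0]
```

Checking the callback's arity is only one way to make old_value optional; a GPU implementation would additionally have to decide how extra user parameters reach update_func.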

Benefits

  • Greatly simplifies updating sparse matrix values
  • Eliminates boilerplate kernel logic and index handling
  • Aligns with Warp's design philosophy of expressive, GPU-friendly computation
  • Enables clean code reuse in simulations, numerical solvers, etc.
  • Encourages best practices (clear update semantics, immutability)
  • Leaves room for Warp to optimise the update internally

akrivx, Aug 06 '25 15:08

Apologies for the delay in answering this request. In the end, I have not added such a high-level API yet. It does not remove that much boilerplate, and difficulties arise when the user function needs additional parameters. There is also the question of whether the current value should be passed to the update function, so I am somewhat concerned that the reusability is not worth the added complexity.

Instead, to avoid having to uncompress the row indices, I have added the bsr_row_index() function to the public API. This probably does not count as high-level, but it should be the most composable option:

from typing import Any

import warp as wp
import warp.sparse as wps


@wp.func
def compute_new_value(i: int, j: int) -> float:
    return float(i + j)


@wp.kernel
def update_values(A: Any):
    block_index = wp.tid()

    row = wps.bsr_row_index(A.offsets, A.nrow, block_index)
    # A.nnz on the host is only an upper bound on the actual number of non-zeros,
    # so the non-zero at block_index may not exist.
    # bsr_row_index returns -1 in such cases.
    if row != -1:
        col = A.columns[block_index]
        A.values[block_index] = compute_new_value(row, col)


# A: an existing BsrMatrix, e.g. built with wps.bsr_from_triplets
wp.launch(
    kernel=update_values,
    dim=A.nnz,
    inputs=[A],
)

gdaviet, Oct 21 '25 15:10
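For intuition about the per-block lookup that bsr_row_index performs, here is a pure-Python model under assumed semantics: a binary search over the row offsets that returns -1 for any block index at or past offsets[nrow], matching the comments in the snippet above. row_index_reference is a hypothetical name, and Warp's actual implementation may differ.

```python
from bisect import bisect_right


def row_index_reference(offsets, nrow, block_index):
    """Hypothetical model of a row lookup from compressed row offsets.

    Returns the row r with offsets[r] <= block_index < offsets[r + 1],
    or -1 when block_index is not a stored non-zero.
    """
    if block_index < 0 or block_index >= offsets[nrow]:
        return -1
    # bisect_right finds the first offset strictly greater than block_index;
    # the owning row is the one just before it (this also skips empty rows).
    return bisect_right(offsets, block_index, 0, nrow + 1) - 1


offsets = [0, 2, 2, 4]  # 3 rows, 4 stored non-zeros, row 1 empty
# Launch over 6 "threads", as if A.nnz over-estimated the real count of 4
print([row_index_reference(offsets, 3, b) for b in range(6)])  # [0, 0, 2, 2, -1, -1]
```

The surplus indices map to -1, which is why the kernel guards its update with `if row != -1`.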