[REQ] High-Level Sparse Matrix Value Update API
Description
Warp currently supports constructing a sparse matrix using warp.sparse.bsr_from_triplets. However, updating the numerical values in a matrix while keeping the sparsity pattern fixed is not straightforward.
This feature request proposes adding a high-level API to update the values of a sparse matrix in-place, without requiring the user to manually reconstruct the internal (i, j) mapping for each entry.
Context
In many applications, such as FEM, physics simulation, or iterative solvers, the sparsity pattern remains constant across time steps or iterations, but the values change.
Currently, the only way (as far as I understand it) to update values without rebuilding the matrix is to:
- Call A.uncompress_rows() to recover the row index for each nonzero.
- Launch a custom kernel that takes rows, A.columns, and A.values.
- Recompute values using a custom @wp.kernel.
This approach is low-level, can be fragile, and requires the user to manage index bookkeeping manually.
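To make the row-recovery step concrete, here is a minimal plain-Python sketch of what expanding compressed row offsets into one row index per stored entry looks like. The helper name `expand_row_offsets` is hypothetical and only illustrates the idea; it is not Warp's implementation of `uncompress_rows()`.

```python
def expand_row_offsets(offsets):
    """Expand CSR-style row offsets into one row index per stored entry.

    `offsets` has length nrow + 1; row r owns entries offsets[r]..offsets[r+1]-1.
    """
    rows = []
    for row in range(len(offsets) - 1):
        count = offsets[row + 1] - offsets[row]
        rows.extend([row] * count)
    return rows


# Three rows: row 0 has 2 entries, row 1 is empty, row 2 has 3 entries.
offsets = [0, 2, 2, 5]
rows = expand_row_offsets(offsets)
print(rows)  # [0, 0, 2, 2, 2]
```

The resulting `rows` list plays the role of the first kernel input in the pattern described above: entry `k` of the matrix sits at `(rows[k], columns[k])`.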
Example: Current Pattern
@wp.kernel
def update_values(rows, cols, vals):
    tid = wp.tid()  # thread/triplet index
    i = rows[tid]
    j = cols[tid]
    vals[tid] = compute_new_value(i, j)  # must be a @wp.func

wp.launch(
    kernel=update_values,
    dim=A.nnz,
    inputs=[A.uncompress_rows(), A.columns, A.values],
)
Drawbacks:
- Requires knowledge of CSR internals (uncompress_rows())
- Manual (i, j) mapping may be error-prone
- Doesn't communicate intent clearly
- Hard to reuse
Proposed API
Introduce a utility like:
warp.sparse.update_values(A, update_func)
Where:
- update_func(i, j, old_value) returns the updated value for position (i, j)
- old_value could be optional if not needed
- Internally, Warp would handle row/col indexing and launch an efficient kernel
This makes the common update pattern explicit, safe, and concise.
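The intended semantics can be sketched with a plain-Python reference loop over a CSR-style layout. The function `update_values_reference` and its arguments are hypothetical and exist only to illustrate the proposal; this is not an existing Warp function, and a real implementation would run as a GPU kernel rather than a Python loop.

```python
def update_values_reference(offsets, columns, values, update_func):
    """Apply update_func(i, j, old_value) to every stored entry, in place.

    The sparsity pattern (offsets, columns) is left untouched.
    """
    for row in range(len(offsets) - 1):
        for k in range(offsets[row], offsets[row + 1]):
            values[k] = update_func(row, columns[k], values[k])


# 3x3 pattern with entries at (0, 0), (0, 2), and (2, 1).
offsets = [0, 2, 2, 3]
columns = [0, 2, 1]
values = [1.0, 1.0, 1.0]

update_values_reference(offsets, columns, values, lambda i, j, old: old + i + j)
print(values)  # [1.0, 3.0, 4.0]
```

The caller only supplies a function of `(i, j, old_value)`; all index bookkeeping stays inside the helper, which is the point of the proposal.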
Optional Enhancements
- Configure whether to mutate the passed matrix or create a new one by passing update_values an in_place=True/False flag
Benefits
- Greatly simplifies updating sparse matrix values
- Eliminates boilerplate kernel logic and index handling
- Aligns with Warp's design philosophy of expressive, GPU-friendly computation
- Enables clean code reuse in simulations, numerical solvers, etc.
- Encourages best practices (clear update semantics, immutability)
- Optimisation potential
Apologies for the delay in answering this request. In the end, I have not added such a high-level API yet. It does not remove that much boilerplate, and difficulties arise when the user function needs additional parameters. There is also the question of whether the current value should be passed to the update function, so I am a bit concerned about whether the reusability is worth the complexity.
Instead, to avoid having to decompress the row indices, I have added the bsr_row_index() function to the public API. This probably does not count as high-level, but it should be the most composable option:
from typing import Any

import warp as wp
import warp.sparse as wps

@wp.func
def compute_new_value(i: int, j: int) -> float:
    return float(i + j)

@wp.kernel
def update_values(A: Any):
    block_index = wp.tid()
    row = wps.bsr_row_index(A.offsets, A.nrow, block_index)

    # A.nnz on host is an upper bound for the actual number of non-zeros,
    # so the non-zero at block_index may not exist.
    # `bsr_row_index` will return `-1` in such cases
    if row != -1:
        col = A.columns[block_index]
        A.values[block_index] = compute_new_value(row, col)

wp.launch(
    kernel=update_values,
    dim=A.nnz,
    inputs=[A],
)
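For readers wondering how a per-block row lookup can work without decompressing the rows, here is a conceptual plain-Python model. It assumes the lookup is a binary search over the row offsets, and the helper name `row_index_of_block` is hypothetical; it mirrors the behavior described above (returning `-1` for a block index past the last stored block) but is not Warp's actual implementation.

```python
from bisect import bisect_right


def row_index_of_block(offsets, nrow, block_index):
    """Return the row that owns `block_index`, or -1 if no such block exists.

    `offsets` has at least nrow + 1 entries; offsets[nrow] is the number of
    blocks actually stored in the matrix.
    """
    if block_index < 0 or block_index >= offsets[nrow]:
        return -1
    # Index of the last row whose starting offset is <= block_index.
    return bisect_right(offsets, block_index, 0, nrow + 1) - 1


# Three rows: row 0 owns blocks 0-1, row 1 is empty, row 2 owns blocks 2-4.
offsets = [0, 2, 2, 5]
lookups = [row_index_of_block(offsets, 3, k) for k in range(7)]
print(lookups)  # [0, 0, 2, 2, 2, -1, -1]
```

Because each thread resolves its own row from the offsets, no intermediate row array needs to be allocated, and threads launched beyond the real block count simply see `-1` and do nothing.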