Merging with sparsediffax?
Hi @mfschubert, thanks for developing this package!
Lately I have worked a lot on sparse autodiff in the Julia language, and some of the products of my work are useful for Python too. In particular, the package SparseMatrixColorings.jl offers fast coloring and decompression for Jacobians and Hessians. I wanted people to be able to use it in JAX, so I created a Python interface called sparsediffax (which calls pysparsematrixcolorings). It was inspired by sparsejac but has some key differences regarding coloring and symmetry handling, which I would be happy to explain in more detail.
It is my very first Python package, and I probably screwed up in many places. Meanwhile, sparsejac has a lot of stuff that I like: better tests, more thorough input validation, an API closer to that of JAX, and an experienced Python developer at its helm. Would you be interested in joining forces, so that there is a single package for sparse autodiff in JAX? I don't care where it lives or what it is called, as long as we're both coauthors.
Hi @gdalle, nice work with sparsediffax. Your proposal is quite appealing. I don't have significant bandwidth to devote to this at the moment, but I am broadly interested in advancing the available tools in this space. And I am expecting to start a project in the medium term which will eventually raise the priority for me personally. Therefore, I would generally be happy to work with you on this goal.
To start, I am definitely eager to hear about the key differences between our packages, and your ideas on how these packages should evolve.
The main differences I see are related to coloring, and I will describe them by referring to the terminology and insights in this (excellent) review:
What Color is your Jacobian?, Gebremedhin, Manne and Pothen (2005)
It provided the basis for much of the work that led to the ColPack implementation in C++, and to our Julia reimplementation SparseMatrixColorings.jl. You can also find more detail about that package in our recent preprint:
Revisiting Sparse Matrix Coloring and Bicoloring, Montoison, Dalle and Gebremedhin (2025)
And if you're curious about the state of sparse autodiff in Julia, you can read our other works:
A Common Interface for Automatic Differentiation, Dalle and Hill (2025)
Sparser, Better, Faster, Stronger: Efficient Automatic Differentiation for Sparse Jacobians and Hessians, Hill and Dalle (2025)
### Graph representation
Let $S$ denote the sparsity pattern of the Jacobian: sparsejac uses $S S^\top$ or $S^\top S$ as the adjacency matrix for the graph on which coloring is run. These are known as the row intersection graph and column intersection graph respectively, and they are much denser than a third option called the bipartite graph. That one simply creates vertices for rows and columns of $S$, with an edge whenever the corresponding coefficient is non-zero, and it is the representation we chose in SparseMatrixColorings.jl. It doesn't change the subsequent coloring, but it does improve the efficiency of graph construction and storage by a factor that is roughly proportional to the average degree.
```julia
julia> using SparseArrays

julia> S = sprand(Bool, 1000, 1000, 0.01)  # example sparsity pattern
1000×1000 SparseMatrixCSC{Bool, Int64} with 9976 stored entries:

julia> nnz(S)  # size of the bipartite graph
9976

julia> nnz(transpose(S) * S)  # size of the column intersection graph
95609
```
### Coloring algorithm
Your choice of graph representation means you can get away with the simple distance-1 greedy coloring provided by networkx. Meanwhile, the more lightweight bipartite graph representation requires a different kind of coloring: a partial distance-2 coloring. Indeed, you only need to color the column or row vertices in the bipartite graph, and those are 2 hops apart from each other. SparseMatrixColorings.jl implements this custom coloring, along with efficient ways to compute relevant node orderings.
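For illustration, here is what a partial distance-2 column coloring looks like (a minimal sketch; `LargestFirst` is one of the vertex orderings I believe the package exports, and the default natural order works too):

```julia
using SparseArrays, SparseMatrixColorings

S = sprand(Bool, 1000, 1000, 0.01)  # random example sparsity pattern

# Greedy partial distance-2 coloring of the column vertices of the bipartite
# graph, with a largest-degree-first vertex ordering.
colors = fast_coloring(
    S,
    ColoringProblem{:nonsymmetric,:column}(),
    GreedyColoringAlgorithm(LargestFirst()),
)

maximum(colors)  # number of colors = number of JVPs needed for compression
```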
### Symmetry handling
As the name suggests, sparsejac is built with Jacobian matrices in mind. For Hessian matrices, the appropriate graph encoding and coloring algorithm are different: one typically computes a star coloring on the adjacency graph associated with $S$. At a high level, this expresses the additional freedom we get during recovery: we can compress some overlapping columns together, because if we don't retrieve a coefficient from $H_{ij}$, we might still get it from its symmetric counterpart $H_{ji}$.
This symmetry-aware approach is beneficial in general, and in some cases it is essential. Consider a function like $f(x) = \sum_i x_i^2 + x_1 \sum_i x_i$. Its Hessian has one dense row and one dense column. When we just compute it as the Jacobian of a gradient, forgetting symmetry, we are forced to use $n$ colors either way, because all columns overlap and all rows overlap. By contrast, the star coloring requires just two colors.
Since the number of colors gives the number of necessary HVPs, this translates into a huge win (at least in theory; in practice, jax.jit can sometimes save your ass).
```julia
julia> using SparseArrays, SparseMatrixColorings

julia> begin
           S = spdiagm(ones(Bool, 10))
           S[1, :] .= 1
           S[:, 1] .= 1
           S
       end
10×10 SparseMatrixCSC{Bool, Int64} with 28 stored entries:
 1  1  1  1  1  1  1  1  1  1
 1  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1

julia> fast_coloring(S, ColoringProblem{:nonsymmetric,:column}(), GreedyColoringAlgorithm())
10-element Vector{Int64}:
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10

julia> fast_coloring(S, ColoringProblem{:symmetric,:column}(), GreedyColoringAlgorithm())
10-element Vector{Int64}:
 1
 2
 2
 2
 2
 2
 2
 2
 2
 2
```
By a similar argument, Jacobian matrices with dense rows and columns are out of reach for standard unidirectional coloring (which corresponds to either forward- or reverse-mode autodiff). That's where bidirectional coloring comes into play, and SparseMatrixColorings.jl offers that too, allowing us to compute sparse Jacobians by mixing both modes.
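Roughly, on the arrowhead pattern from above, it looks like this (a minimal sketch; the exact options of `coloring` and the accessors may vary between versions):

```julia
using SparseArrays, SparseMatrixColorings

# Arrowhead pattern from above: one dense row and one dense column.
S = spdiagm(ones(Bool, 10))
S[1, :] .= 1
S[:, 1] .= 1

# Bidirectional coloring assigns colors to both rows and columns, so the
# dense row can be handled by a VJP while a few JVPs cover the rest.
problem = ColoringProblem{:nonsymmetric,:bidirectional}()
result = coloring(S, problem, GreedyColoringAlgorithm())
row_colors(result), column_colors(result)
```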
### Preparation for decompression
With these more sophisticated colorings, decompression becomes less obvious. In particular, we need to decide which color we will use to retrieve coefficient $(i, j)$: the row color $c_i$ or the column color $c_j$? For a fully vectorized decompression, this needs to be done at compile time. To enable that, sparsediffax reaches into SparseMatrixColorings.jl data structures and precomputes this indexing information, which was not needed in the unidirectional Jacobian case.
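For reference, here is roughly what the compression/decompression round trip looks like in SparseMatrixColorings.jl (a minimal sketch using `compress` and `decompress` on a matrix with known values, whereas in autodiff the compressed columns would come from JVPs):

```julia
using SparseArrays, SparseMatrixColorings

A = sprand(100, 100, 0.05)  # example sparse matrix with known values
S = map(!iszero, A)         # its sparsity pattern

result = coloring(S, ColoringProblem{:nonsymmetric,:column}(), GreedyColoringAlgorithm())

B = compress(A, result)     # 100 × ncolors: columns of each color group summed
A2 = decompress(B, result)  # route every coefficient back to its position
A2 == A                     # should hold: structurally orthogonal columns don't collide
```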
### Sparsity detection
Ideally, sparsity detection would grab the jaxpr representation of a function and work magic with it. In practice, this is really hard, and although people are trying (see jax2sympy for one approach), I think it would be nice to provide a built-in tool, however rough. Right now, sparsediffax offers naive sparsity detectors, which just compute the Jacobian or Hessian densely at a given point and check the nonzero entries. I plan to refine them a little to propagate NaNs instead of values through the function, thus negating the effect of control flow (or erroring appropriately), to ensure a globally valid sparsity pattern.
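The detectors themselves live on the Python side, but the naive idea fits in a few lines; here is the same logic sketched in Julia for consistency with the snippets above (the helper name `naive_jacobian_pattern` is made up, and ForwardDiff stands in for jax.jacfwd):

```julia
using ForwardDiff, SparseArrays

# Naive *local* sparsity detection: differentiate densely at one point and
# keep the pattern of nonzero entries. Control flow elsewhere in the input
# space can invalidate this pattern, hence the NaN-propagation idea above.
function naive_jacobian_pattern(f, x0)
    J = ForwardDiff.jacobian(f, x0)
    return sparse(J .!= 0)
end

f(x) = x .^ 2 .+ x .* circshift(x, 1)  # toy function with a sparse Jacobian
naive_jacobian_pattern(f, rand(5))
```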
What do you think?
Hi @gdalle, thanks for this great summary. These all sound like excellent additions. To summarize, it seems like things fit into two categories: user-facing changes and those which are under the hood. On the user side, these are:
- adding a `hessian` function to compute sparse Hessians
- adding a mechanism to compute the sparsity pattern if none is given
Under the hood, we have improved graph representation, coloring, and all the extra logic needed to support Hessians efficiently. Here, we can significantly leverage the Julia packages you have been developing.
With this in mind, I would propose the following.
- Julia packages which are leveraged get their own pip-installable Python wrapper packages, so that e.g. `sparsejac` would depend on `pysparsematrixcolorings`. This seems like a good way to limit complexity and also enable future developers who might have other uses for these packages.
- We create an `experimental` subpackage within `sparsejac`, where the new features and implementations can be incubated. Once the API settles down and we have good test coverage, we can move them out of `experimental`. Alternative implementations and other new capabilities can continue to be explored in `experimental`.
- A refactor of `sparsejac` into submodules should be undertaken so that things are sensibly organized in the context of these new features.
Let me know what you think. In terms of the practicalities of working together, and coauthorship, I am open to suggestions.
> To summarize, it seems like things fit into two categories: user-facing changes and those which are under the hood.
I agree with your summary.
> Julia packages which are leveraged get their own pip-installable Python wrapper packages, so that e.g. `sparsejac` would depend on `pysparsematrixcolorings`. This seems like a good way to limit complexity and also enable future developers who might have other uses for these packages.
I'm working on that for pysparsematrixcolorings. I'm a bit new at this, so I'm not sure how best to translate the interface. I'm also discovering all the tooling around Python packaging and registration, so it might take a little while before everything is pip-friendly.
> We create an `experimental` subpackage within `sparsejac`, where the new features and implementations can be incubated. Once the API settles down and we have good test coverage, we can move them out of `experimental`. Alternative implementations and other new capabilities can continue to be explored in `experimental`.
When I said I didn't care what the package was called or where it lived, that was a slight exaggeration ^^ I think sparsejac is not ideal as a name, since we want to include Hessians. It also shows no connection to JAX, unlike many package names in that ecosystem which end with an x.
On the other hand, I appreciate the importance of not releasing anything half-baked, and of letting people who use sparsejac continue to use it undisturbed, at least until we figure things out. That's why I would favor experimental development in a separate repo, be it named sparsediffax or something else. Such a split would also allow me to get my hands dirty with CI setup, testing, docs, etc. without major consequences.
> Let me know what you think. In terms of the practicalities of working together, and coauthorship, I am open to suggestions.
The core logic of sparsediffax is the same as that of sparsejac: compute a few JVPs/VJPs/HVPs (dictated by the color vector), then decompress them into a BCOO format. What I really like in sparsejac is the craftsmanship around it: you handle more edge cases, have more tests, and overall better code quality. I'd like to reuse that if possible, or at least take inspiration from it.
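To make that shared core concrete, here is a rough Julia transliteration of the pipeline (a sketch only: the `jvp` helper is an illustrative stand-in for jax.jvp, and the Python packages produce a BCOO rather than a SparseMatrixCSC):

```julia
using ForwardDiff, SparseArrays, SparseMatrixColorings

f(x) = x .^ 2 .+ x .* circshift(x, 1)  # toy function with a sparse Jacobian
x = rand(8)

S = sparse(ForwardDiff.jacobian(f, x) .!= 0)  # sparsity pattern (naive detection)
result = coloring(S, ColoringProblem{:nonsymmetric,:column}(), GreedyColoringAlgorithm())
c = column_colors(result)

# One JVP per color, seeding with the indicator vector of each color group.
jvp(f, x, v) = ForwardDiff.derivative(t -> f(x .+ t .* v), 0.0)
B = stack(jvp(f, x, Float64.(c .== k)) for k in 1:maximum(c))

J = decompress(B, result)  # sparse Jacobian from only maximum(c) products
```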
My suggestion would be to add an MIT license to sparsediffax with us both as authors, and an explicit reference in the README saying that some code was copied from sparsejac. With that, I think we should be okay on the legalese, at least software-wise. Then, once we get the new version to the level of quality we want to see, I would be enthusiastic to try it out on real cases, and why not write a small workshop paper together?
@gdalle This works for me! I am happy to have development take place in sparsediffax, and to point users toward that repo on the sparsejac homepage. And agreed that a workshop paper is an outcome worth targeting.