Speed up the `_partition` utility

Open Marius1311 opened this issue 2 years ago • 2 comments

As we saw recently when we were debugging @WeilerP's data example, it makes sense to have a utility to learn something about transient and recurrent classes in a Markov chain. Currently, that's done by the `_partition` utility function. Compute time can be decomposed into the following three tasks:

  1. Initialise the directed graph using nx.DiGraph(conn). As we saw in @WeilerP's example, this only takes a few seconds, even for 70k cells.
  2. Decompose the state space into strongly connected components (SCCs) (called communication classes in Markov chain language) using nx.strongly_connected_components(g). Under the hood, this uses Tarjan's algorithm or a variant of it, which runs in linear time. Again, we saw in @WeilerP's example that this is very fast for 70k cells.
  3. For each SCC, figure out whether it's recurrent or transient. That's by far the most expensive component of this utility.

Let's focus on the third point, since it's the bottleneck. In code, we currently solve step 3 as follows:

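    from itertools import product

    import networkx as nx

    # `sort` is not defined in this snippet; it is the `sort` parameter of the surrounding `_partition` utility.
    def partition(g):
        yield from (
            (
                (sorted(scc) if sort else scc),
                # Recurrent iff no node outside the SCC is reachable from any node inside it.
                all((not nx.has_path(g, s, t) for s, t in product(scc, g.nodes - scc))),
            )
            for scc in nx.strongly_connected_components(g)
        )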
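
For comparison, here is a minimal sketch of how step 3 could avoid the pairwise `nx.has_path` calls while staying in networkx. It is a swapped-in technique, not the current implementation: it builds the condensation graph with `nx.condensation` and marks an SCC as recurrent when its condensation node has no outgoing edges. The function name `partition_via_condensation` is hypothetical.

```python
import networkx as nx


def partition_via_condensation(g):
    """Yield (scc_members, is_recurrent) pairs for a directed graph `g`.

    Hypothetical helper, not part of cellrank. In the condensation graph,
    every SCC is a single node and edges only connect different SCCs, so an
    SCC is recurrent exactly when its node has out-degree zero.
    """
    c = nx.condensation(g)
    for node, members in c.nodes(data="members"):
        yield members, c.out_degree(node) == 0


# Toy chain: {0, 1} leaks into {2, 3}, and {2, 3} is closed.
g = nx.DiGraph([(0, 0), (0, 1), (1, 0), (1, 2), (2, 3), (3, 2)])
classes = list(partition_via_condensation(g))
```

Building the condensation is linear in the number of nodes and edges, whereas the snippet above runs one path search per (inside, outside) node pair.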

I think this can be done much more efficiently by working directly with the transition matrix. Given the transition matrix T and the S index sets {I_s}_{s=1}^S of the SCCs, where index set I_s holds the indices of the nodes assigned to SCC s, you simply have to check whether the rows of T corresponding to I_s have non-zero elements in columns which are not in I_s. If that's the case, SCC s is transient; otherwise, it's recurrent. In other words, you try to rearrange T into block-diagonal form: every SCC where this works is recurrent, every SCC where this fails is transient. That should be much more efficient, I think.
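
A minimal sketch of this check, assuming T is available as a SciPy sparse matrix. The function name `classify_sccs` is hypothetical, and the SCCs are computed with `scipy.sparse.csgraph.connected_components` rather than networkx purely to keep the example self-contained:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components


def classify_sccs(T):
    """Return (labels, is_recurrent) for a sparse transition matrix `T`.

    Hypothetical helper, not part of cellrank. `labels[i]` is the SCC index
    of state i, and `is_recurrent[s]` is True when no row in SCC s has a
    non-zero entry in a column outside SCC s.
    """
    T = csr_matrix(T, copy=True)
    T.eliminate_zeros()  # explicitly stored zeros must not count as edges
    n_sccs, labels = connected_components(T, directed=True, connection="strong")

    coo = T.tocoo()
    # A stored entry (i, j) leaves its SCC when i and j carry different labels.
    leaving = labels[coo.row] != labels[coo.col]

    is_recurrent = np.ones(n_sccs, dtype=bool)
    is_recurrent[labels[coo.row[leaving]]] = False
    return labels, is_recurrent


# Toy chain: states {0, 1} leak into {2, 3}, and {2, 3} is closed.
T = csr_matrix(
    np.array(
        [
            [0.5, 0.5, 0.0, 0.0],
            [0.5, 0.0, 0.5, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ]
    )
)
labels, is_recurrent = classify_sccs(T)
```

The check touches every stored entry of T once, so after the SCC decomposition the cost is linear in the number of non-zeros.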

Alternatively, for every SCC s, you could also compute, for each row i in I_s, the sum of T_{i, j} over all j in I_s. If every such sum is equal to 1, then s is recurrent; if not, it's transient. However, you have to be careful here because of numerical precision: we know rows in T sometimes don't sum exactly to one, and we usually check for this with some thresholds, so the threshold for this check would have to be chosen with care.
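
A rough sketch of the row-sum variant, again assuming a SciPy sparse T. The function name `is_recurrent_by_mass` and the default `atol` are hypothetical. As a deliberate variation on the description above, it compares the mass that stays inside the SCC against each row's own total rather than against 1, so rows that don't sum exactly to one are not misclassified for that reason alone:

```python
import numpy as np
from scipy.sparse import csr_matrix


def is_recurrent_by_mass(T, members, atol=1e-8):
    """Check whether the SCC with state indices `members` is recurrent.

    Hypothetical helper, not part of cellrank. The SCC counts as recurrent
    when, for every row in it, the probability mass staying inside the SCC
    matches the full row sum up to `atol`.
    """
    T = csr_matrix(T)
    members = np.asarray(members)
    rows = T[members]
    total = np.asarray(rows.sum(axis=1)).ravel()
    inside = np.asarray(rows[:, members].sum(axis=1)).ravel()
    return bool(np.allclose(inside, total, rtol=0.0, atol=atol))


# Toy chain: states {0, 1} leak into {2, 3}, and {2, 3} is closed.
T = csr_matrix(
    np.array(
        [
            [0.5, 0.5, 0.0, 0.0],
            [0.5, 0.0, 0.5, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ]
    )
)
closed = is_recurrent_by_mass(T, [2, 3])
leaky = is_recurrent_by_mass(T, [0, 1])
```

Note that any SCC whose rows leak less than `atol` of probability mass would be reported as recurrent, which is exactly the threshold sensitivity mentioned above.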

All in all, this is not super high priority, but I think it's an easy fix that's useful for debugging.

Minor points

  • Parameter `sort` is not described in the docstring.
  • Parameter description of conn: "Directed graph to partition.", there's an unnecessary "" here.

Marius1311 • Dec 06 '21 08:12

@michalk8, do you think it's worthwhile to implement this? Should be quite an easy enhancement and quite useful for debugging. Your call.

Marius1311 • Mar 04 '22 02:03

Hi @michalk8, is the partition utility currently still part of our API? I would vote strongly in favor of keeping it.

Marius1311 • Jun 21 '22 12:06

Unfortunately, I won't have time to look into this in the near future; I also think this is really niche. Closing due to inactivity and low priority.

michalk8 • Jun 07 '23 01:06

I'll get back to this when we need it, but it's good to close for now.

Marius1311 • Jun 07 '23 12:06