
Sort Markov matrix

Open cap-jmk opened this issue 2 years ago • 21 comments

Is your feature request related to a problem? Please describe. I am doing Markov modelling for SAR/QSAR analysis of chemical compounds and would need sorted Markov matrices.

I suggest sorting the Markov matrix according to the most stable state, i.e. by descending self-transition probability. Something like the following, ideally with better memory management:

import numpy as np

def sort_markov_matrix(markov_matrix):
    """Sort a Markov matrix in place by descending self-transition probability.

    Args:
        markov_matrix (np.ndarray): unsorted transition matrix

    Returns:
        np.ndarray: sorted Markov matrix
    """
    n = len(markov_matrix)
    for i in range(n):
        for j in range(i + 1, n):
            # compare self-transition probabilities on the diagonal;
            # markov_matrix[i, i] is re-read each time, so it stays
            # current after a swap
            if markov_matrix[j, j] > markov_matrix[i, i]:
                # swap rows i and j, then columns i and j
                markov_matrix[[i, j]] = markov_matrix[[j, i]]
                markov_matrix[:, [i, j]] = markov_matrix[:, [j, i]]
    return markov_matrix

Test with


def test_sort():
    a = np.array([[0.8, 0.1, 0.05, 0.05],
                  [0.005, 0.9, 0.03, 0.015],
                  [0.1, 0.2, 0.4, 0.3],
                  [0.01, 0.02, 0.03, 0.94]])
    sorted_a = sort_markov_matrix(a)
    assert np.array_equal(sorted_a[0, :], np.array([0.94, 0.02, 0.01, 0.03])), str(sorted_a[0, :])
    assert np.array_equal(sorted_a[1, :], np.array([0.015, 0.9, 0.005, 0.03])), str(sorted_a[1, :])
    assert np.array_equal(sorted_a[2, :], np.array([0.05, 0.1, 0.8, 0.05])), str(sorted_a[2, :])
    assert np.array_equal(sorted_a[3, :], np.array([0.3, 0.2, 0.1, 0.4])), str(sorted_a[3, :])

What do you think?

cap-jmk avatar Mar 07 '22 15:03 cap-jmk

Hi, I think this may be a bit too specific to implement as a default. There are different ways of understanding the stability of a Markov state, I would say; for instance, you could

  • look at the main diagonal of the transition matrix as you did, or
  • you could look at the probability distribution in the stationary process,
  • you could also think about stable groups of states (like in PCCA+)

In that sense it might be better for each user to implement their own version of such a relabeling. A more efficient variant of yours could, for instance, be implemented with permutation matrices. :slightly_smiling_face:

so...

msm = estimate_msm(data)
msm_sorted = deeptime.markov.msm.MarkovStateModel(sort_msm(msm.transition_matrix))
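To illustrate the second bullet, here is a self-contained NumPy sketch that reorders by stationary probability rather than by the diagonal; `sort_by_stationary` is just a made-up helper name, and the stationary vector is taken from the leading left eigenvector:

```python
import numpy as np

def sort_by_stationary(P):
    """Reorder states so the most probable state in the stationary
    distribution comes first (illustrative helper, not deeptime API)."""
    # stationary distribution: left eigenvector of P for eigenvalue 1
    evals, evecs = np.linalg.eig(P.T)
    pi = np.real(evecs[:, np.argmax(np.real(evals))])
    pi /= pi.sum()                   # normalize (fixes the sign, too)
    order = np.argsort(pi)[::-1]     # descending stationary probability
    return P[np.ix_(order, order)], pi[order]

P = np.array([[0.9, 0.1, 0.0],
              [0.2, 0.7, 0.1],
              [0.1, 0.2, 0.7]])
P_sorted, pi_sorted = sort_by_stationary(P)
```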

clonker avatar Mar 09 '22 09:03 clonker

Maybe you are right; however, it feels like it belongs to the overall Markov modelling that the deeptime package covers. I can certainly implement it on my side, but it feels strange. Implementing it in deeptime would also improve the readability of dependent code in general, i.e.

msm = estimate_msm(data, sorted=True)

I think the transition matrix gives a good estimate of the states and serves as a fingerprint for my case. PCCA+ seems like a good idea and could be helpful in some cases.

How would you do it with permutation matrices? The best we could achieve is O(n), right? What would the memory consumption of permutation matrices look like? I remember some applications of these matrices in solving linear systems. Would your solution be similar?

cap-jmk avatar Mar 10 '22 15:03 cap-jmk

To my knowledge, there is no canonical way of sorting Markov states, so I do not think it is a good idea to make this a True/False decision. What could be done is to offer a relabeling function such as

msm_sorted = msm.relabel(np.argsort(np.diag(msm.transition_matrix)), inplace=False)

in your case. I do not have the capacity to implement this right now but am happy to give pointers and work on pull requests with you. There are multiple layers to this, though. In particular, we will have to be very careful that this does not break other parts of the library that assume the Markov states stay the same over the course of taking submodels (for example, when restricting yourself to the largest connected component of the jump-probability connectivity graph). Also, there are the following cases to keep in mind:

  • Markov model without statistics: this should be relatively straightforward
  • Markov model with statistics: here we have to be careful to also relabel the statistics to keep everything consistent
  • Effect on Markov state model collections (in particular MEMMs)
  • Effect on hidden Markov models
  • Implementation on sparse transition matrices / count matrices

There are probably more things to keep in mind here. In any case, I think the easiest approach for you is to sort the matrix on your own and create a new MSM instance.

Regarding permutation matrices: Yes, we cannot get better than O(n), but we can achieve vectorization.
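To sketch what I mean by vectorization (the example values are made up): conjugating with a permutation matrix relabels rows and columns in one shot, and is equivalent to a single fancy-indexing step:

```python
import numpy as np

rng = np.random.default_rng(42)
P = rng.uniform(0, 1, size=(5, 5))
P /= P.sum(axis=1, keepdims=True)        # make rows stochastic

order = np.argsort(np.diag(P))[::-1]     # most metastable state first
perm = np.eye(len(order))[order]         # permutation matrix

# both relabel rows and columns simultaneously
P_via_matmul = perm @ P @ perm.T
P_via_index = P[np.ix_(order, order)]    # no matmul needed

assert np.allclose(P_via_matmul, P_via_index)
```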

clonker avatar Mar 11 '22 09:03 clonker

Yes, I know what you mean. I will do my best to support you.

cap-jmk avatar Mar 13 '22 11:03 cap-jmk

Cool, thanks! :rocket: I think a good first step would be reordering count matrices (in TransitionCountModel). Do you want to take a stab at that? I am still not entirely sure what such a method should be called, as it is not really a sorting but rather a relabeling, in general at least. Perhaps permute? Or reorder? At first I thought transpose might be a good fit, but that is really more used in the context of axes.

clonker avatar Mar 14 '22 09:03 clonker

@clonker, yes, I could give it a try. Where would you want the change to go? I would put the sorting in deeptime/markov/_transition_counting.py.

cap-jmk avatar Mar 28 '22 12:03 cap-jmk

Yes, that would be a good start!

clonker avatar Mar 28 '22 14:03 clonker

Nice. Okay, I have a working sorting algorithm implemented. However, I would love for you to review it before I start implementing it in deeptime. I don't know why, but I could only make it work with a bubble sort on the diagonal.

import numpy as np

def sort_markov_matrix(markov_matrix):
    """Bubble-sort a Markov matrix in place by descending diagonal.

    Args:
        markov_matrix (np.ndarray): unsorted transition matrix

    Returns:
        np.ndarray: sorted Markov matrix
    """
    # np.diag returns a view, so it stays in sync with the swaps below
    diag = np.diag(markov_matrix)
    n = len(diag)
    for i in range(n):
        for j in range(n - 1):
            if diag[j + 1] > diag[j]:
                markov_matrix[[j, j + 1]] = markov_matrix[[j + 1, j]]
                markov_matrix[:, [j, j + 1]] = markov_matrix[:, [j + 1, j]]
    return markov_matrix
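Out of curiosity: would a single argsort plus fancy indexing give the same result, without the double loop? A sketch (the function name is just a placeholder, and I have not checked edge cases like ties):

```python
import numpy as np

def sort_markov_matrix_vectorized(markov_matrix):
    """Return a copy reordered so the diagonal is descending."""
    order = np.argsort(np.diag(markov_matrix))[::-1]
    # np.ix_ permutes rows and columns in a single indexing step
    return markov_matrix[np.ix_(order, order)]
```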

cap-jmk avatar Apr 13 '22 14:04 cap-jmk

So here is a version for dense matrices, ideally we would support both dense and sparse though:

import numpy as np
from deeptime.markov.msm import MarkovStateModel

# random row-stochastic transition matrix
P = np.random.uniform(0, 1, size=(5, 5))
P /= P.sum(1)[:, None]
msm = MarkovStateModel(P)

# sort states by descending self-transition probability
diag = np.diag(msm.transition_matrix)
sorting = np.argsort(diag)[::-1]

# permutation matrix: reorders rows and columns via perm @ P @ perm.T
perm = np.eye(len(sorting), dtype=msm.transition_matrix.dtype)[sorting]
msm_reordered = MarkovStateModel(np.linalg.multi_dot((perm, msm.transition_matrix, perm.T)))
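For the sparse side, the same index-based reordering might carry over without densifying; a sketch with `scipy.sparse` (how this would hook into the count models is still open):

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
P = rng.uniform(0, 1, size=(5, 5))
P /= P.sum(axis=1, keepdims=True)
P_sparse = sp.csr_matrix(P)

order = np.argsort(P_sparse.diagonal())[::-1]

# CSR supports indexing with integer arrays, so rows and columns
# can be permuted without converting to a dense array
P_sorted = P_sparse[order][:, order]

assert np.allclose(P_sorted.toarray(), P[np.ix_(order, order)])
```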

clonker avatar Apr 14 '22 07:04 clonker

I see what you meant. With the multi-dot, you would always do Θ(2n) operations, whereas if you implement the sorting manually, you would do O(sqrt(n)) operations. Or am I overlooking something?

cap-jmk avatar Apr 14 '22 12:04 cap-jmk

~Nope, not overlooking anything.~ While it probably warrants a benchmark, I would imagine that multi-dot outperforms manual sorting in Python, though. Things are different if you implement the sorting in an extension.

Edit: Actually, matrix multiplication is (naively) Θ(n³). In any case, here we can see that complexity ≠ efficiency. 🙂
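A rough way to check this locally (timings are machine- and BLAS-dependent, so take the numbers with a grain of salt):

```python
import timeit

import numpy as np

n = 500
rng = np.random.default_rng(1)
P = rng.uniform(0, 1, size=(n, n))
P /= P.sum(axis=1, keepdims=True)

order = np.argsort(np.diag(P))[::-1]
perm = np.eye(n)[order]

t_matmul = timeit.timeit(lambda: np.linalg.multi_dot((perm, P, perm.T)), number=5)
t_index = timeit.timeit(lambda: P[np.ix_(order, order)], number=5)
print(f"multi_dot: {t_matmul:.4f}s, fancy indexing: {t_index:.4f}s")

# both routes yield the same reordered matrix
assert np.allclose(np.linalg.multi_dot((perm, P, perm.T)), P[np.ix_(order, order)])
```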

clonker avatar Apr 14 '22 12:04 clonker

[benchmark plot]

clonker avatar Apr 14 '22 13:04 clonker

Totally enlightening. Just, since you have the benchmark already written, I would be interested in how it behaves when we go beyond 10k samples. That's usually where things get messy.

cap-jmk avatar Apr 15 '22 12:04 cap-jmk

To satisfy your curiosity: [benchmark plot]

Now what would be interesting is the scaling behavior against a C/C++-coded sorting extension and against sparse matrices. Estimating a dense transition matrix with 10k Markov states is a tough task anyway because of the massive amount of data you would need.

clonker avatar Apr 19 '22 07:04 clonker

Okay, I see. Maybe putting the loop into @njit() could help? I don't see why it should be slower. Re-indexing should be faster than multiplying loads of elements, I guess. Code:

import numpy as np
from numba import njit

@njit(parallel=True)
def sort_markov_matrix(markov_matrix):
    """Bubble-sort a Markov matrix in place by descending diagonal.

    Args:
        markov_matrix (np.ndarray): unsorted transition matrix

    Returns:
        np.ndarray: sorted Markov matrix
    """
    diag = np.diag(markov_matrix)
    for i in range(len(diag)):
        for j in range(len(diag) - 1):
            if diag[j + 1] > diag[j]:
                markov_matrix[[j, j + 1]] = markov_matrix[[j + 1, j]]
                markov_matrix[:, [j, j + 1]] = markov_matrix[:, [j + 1, j]]
    return markov_matrix
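If the fancy-indexing swaps are what nopython mode trips over, a loop-only variant with scalar swaps might compile; shown here as plain Python so it runs without the numba dependency (untested under njit itself):

```python
import numpy as np

def sort_markov_matrix_loops(markov_matrix):
    """Bubble-sort states in place by descending diagonal using only
    scalar element swaps (the kind nopython mode usually supports)."""
    n = markov_matrix.shape[0]
    for _ in range(n):
        for j in range(n - 1):
            if markov_matrix[j + 1, j + 1] > markov_matrix[j, j]:
                for k in range(n):  # swap rows j and j + 1
                    markov_matrix[j, k], markov_matrix[j + 1, k] = \
                        markov_matrix[j + 1, k], markov_matrix[j, k]
                for k in range(n):  # swap columns j and j + 1
                    markov_matrix[k, j], markov_matrix[k, j + 1] = \
                        markov_matrix[k, j + 1], markov_matrix[k, j]
    return markov_matrix
```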

cap-jmk avatar Apr 19 '22 20:04 cap-jmk

njit doesn't work for this function on my machine, and I don't really want to pull another dependency into deeptime. If you want to put together a Python-bound C++ implementation, then I'm happy to benchmark it, though. The JIT performance is comparable to the Python loop. In any case, I think the vectorized permutation-matrix implementation is a good middle ground between a lot of implementation work and harder-to-maintain code (C++ extension) versus easy to write and maintain but poor performance (Python loop).

clonker avatar Apr 20 '22 07:04 clonker

Good point. I would be happy to provide a C++ implementation. However, I am not sure how to link it. Do you have any resources on that? Then let's go with your implementation?

cap-jmk avatar Apr 20 '22 12:04 cap-jmk

I think my implementation would be a good way to move forward, yes. If you want to have a look at C++ extensions in general: we are using pybind11, and the extensions are compiled, linked, and installed using CMake. Here is an example of that.

clonker avatar Apr 20 '22 13:04 clonker

Hi, any progress on this?

clonker avatar Aug 15 '22 11:08 clonker

Yes, I just learned some more C++ (and uni politics) and will have more time from now on to work on it :)

cap-jmk avatar Aug 19 '22 14:08 cap-jmk

Cool, let me know if you need pointers / help!

clonker avatar Aug 19 '22 15:08 clonker