
Implementation of Markov channel models

Open jonathanoesterle opened this issue 5 years ago • 3 comments

Hi all,

I would like to implement Markov models for ion channels in Brian2 (as already discussed with @mstimberg). It seems possible without any changes to the core of Brian2. I'm trying to implement the Gillespie algorithm and succeeded for a very simple 2-state, 2-transition model (see below).

    alpha
A <------> B
    beta

However, I couldn't figure out a good way to generalize it to arbitrary channels and transition matrices, since (as far as I understand) input variables to cython/cpp functions must be scalars and can't be arrays. Another obvious problem with the code is that the same function gets called several times (once per channel state per time step); this could also be avoided with vector outputs plus indexing.

So my main question is: Is there a way to use arrays in the equations?

(I thought about using a single integer to encode several small integers, but that wouldn't be a very clean solution; see the toy sketch below.)
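
For illustration, here is roughly what I had in mind (a toy sketch with made-up helper names; not something I'd actually want to keep):

def pack_states(X0, X1):
    # Squeeze two small non-negative counts (each < 2**16) into one integer.
    return (int(X0) << 16) | int(X1)

def unpack_states(packed):
    # Recover the two counts from the packed integer.
    return packed >> 16, packed & 0xFFFF

assert unpack_states(pack_states(100, 0)) == (100, 0)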

Any suggestions on how to proceed from here are very welcome.

Thanks in advance, best regards, Jonathan

The code:

Gillespie algorithm

@implementation('cython', '''
import numpy as np
def get_X_change(X0, X1, rate_01, rate_10, dt, seed0, X_idx_return):
    
    X0 = int(X0)
    X1 = int(X1)
    X_idx_return = int(X_idx_return)
    
    
    # Set seed.
    np.random.seed(int(seed0*np.iinfo(np.int32).max))
    
    # Compute transitions.
    n_trans      = 2                            # Number of transitions.
    trans_rate   = np.array([rate_01, rate_10]) # Rate of transitions.
    trans_source = np.array([0, 1])             # Source states of transitions.
    trans_target = np.array([1, 0])             # Target states of transitions.
    
    # Set states.
    X = np.array([X0, X1])
    
    # Get change of states
    dX = np.array([0, 0])
                            
    # Sample transitions.
    t = 0
    while True:
        
        # Get transition lambdas.
        lam = np.zeros(n_trans)
        for i, Xi in enumerate(X):
            lam += trans_rate * (trans_source == i) * Xi
    
        # Get cumulative lambdas.
        lam_cum = np.cumsum(lam)

        # Stop if no transition can occur anymore.
        if lam_cum[lam_cum.size-1] <= 0:
            break

        # Normalize cumulative lambdas (safe now, the total is positive).
        lam_norm = lam_cum / lam_cum[lam_cum.size-1]
        
        # Get transition time.
        r0 = np.random.uniform(0,1)
        trans_dt = -np.log(r0) / lam_cum[lam_cum.size-1]
        
        # Update t.
        t = t + trans_dt
        
        # Stop if the transition time falls outside this time step.
        if t > dt:
            break
        
        # Get transition index.
        r1 = np.random.uniform(0,1)
        trans_idx = int(np.sum(r1 > lam_norm))
        
        # Execute transition. Update source and target X.
        dX[trans_source[trans_idx]] -= 1
        dX[trans_target[trans_idx]] += 1
        
        X[trans_source[trans_idx]] -= 1
        X[trans_target[trans_idx]] += 1
           
    return dX[X_idx_return]
    '''
)
@check_units(X0=1, X1=1, rate_01=Hz, rate_10=Hz, dt=second, seed0=1, X_idx_return=1, result=1)
def get_X_change(X0, X1, rate_01, rate_10, dt, seed0, X_idx_return):
    
    # Set seed.
    np.random.seed(int(seed0*np.iinfo(np.int32).max))
    
    # Compute transitions.
    n_trans      = 2                            # Number of transitions.
    trans_rate   = np.array([rate_01, rate_10]) # Rate of transitions.
    trans_source = np.array([0, 1])             # Source states of transitions.
    trans_target = np.array([1, 0])             # Target states of transitions.
    
    # Set states.
    X = np.array([X0, X1])
    
    # Get change of states
    dX = np.array([0, 0])
                            
    # Sample transitions.
    t = 0 * second
    while True:
        
        # Get transition lambdas.
        lam = np.zeros(n_trans)
        for i, Xi in enumerate(X):
            lam += trans_rate * (trans_source == i) * Xi
    
        # Get cumulative lambdas.
        lam_cum = np.cumsum(lam)

        # Stop if no transition can occur anymore.
        if lam_cum[lam_cum.size-1] <= 0:
            break

        # Normalize cumulative lambdas (safe now, the total is positive).
        lam_norm = lam_cum / lam_cum[lam_cum.size-1]
        
        # Get transition time.
        r0 = np.random.uniform(0,1)
        trans_dt = -np.log(r0) / lam_cum[lam_cum.size-1] * second
        
        # Update t.
        t = t + trans_dt
        
        # Stop if the transition time falls outside this time step.
        if t > dt:
            break
        
        # Get transition index.
        r1 = np.random.uniform(0,1)
        trans_idx = np.sum(r1 > lam_norm)
        
        # Execute transition. Update source and target X.
        dX[trans_source[trans_idx]] -= 1
        dX[trans_target[trans_idx]] += 1
        
        X[trans_source[trans_idx]] -= 1
        X[trans_target[trans_idx]] += 1
           
    return dX[X_idx_return]

The brian2 model

alpha = 100*Hz
beta  = 100*Hz

X_tot = 100

eqs = Equations('''
seed = rand() : 1 (constant over dt) # Necessary because get_X_change is called several times per time step and the result shouldn't vary within a time step

rate_01 = alpha : Hz
rate_10 = beta : Hz

dX0/dt = get_X_change(X0, X1, rate_01, rate_10, dt, seed, 0) / dt : 1
dX1/dt = get_X_change(X0, X1, rate_01, rate_10, dt, seed, 1) / dt : 1

''')

Run and record:

start_scope()

NM = NeuronGroup(1, eqs, method='euler')
X0 = StateMonitor(NM, 'X0', record=True)
X1 = StateMonitor(NM, 'X1', record=True)

# initial states
NM.X0 = X_tot
NM.X1 = 0

run(100*ms)

jonathanoesterle commented on Jun 27 '19

Result for alpha = beta = 100 Hz:

[figure: recorded traces of X0 and X1 over the 100 ms run]

jonathanoesterle commented on Jun 27 '19

Hi. Sorry, it took me a while to get around to looking into this. Unfortunately, the equations as such do not allow the use of arrays or other data structures, and supporting this would be somewhat complicated, since equations are expressed from the point of view of a single neuron. However, user-defined functions like yours can make use of arrays stored outside of the model. These can serve as global storage, and can also store results for later use so that you do not have to calculate them twice. For this, you hand over references to these arrays via @implementation's namespace argument; from within the code, you can then access the variables as _namespace<VARNAME>.

Here's an example that adapts your code so that it stores the results of the calculation for all states, together with the time at which it last did this calculation. When it is then asked for the change in X1, it does not have to calculate everything again:

X_changes = np.zeros(2, dtype=np.int32)
X_last_calculation = np.array([-1])  # dummy value for "never"

@implementation('cython', '''
import numpy as np
def get_X_change(X0, X1, rate_01, rate_10, t, dt, X_idx_return):
    X_idx_return = int(X_idx_return)
    if t == _namespaceX_last_calculation[0]:
        # No need to calculate things
        return _namespaceX_changes[X_idx_return]
    else:
        _namespaceX_last_calculation[0] = t

    # Compute transitions.
    n_trans      = 2                            # Number of transitions.
    trans_rate   = np.array([rate_01, rate_10]) # Rate of transitions.
    trans_source = np.array([0, 1])             # Source states of transitions.
    trans_target = np.array([1, 0])             # Target states of transitions.

    # Set states.
    X = np.array([X0, X1])

    # Initialize change of states
    _namespaceX_changes[0] = 0
    _namespaceX_changes[1] = 0

    # Sample transitions.
    _t = 0
    while True:

        # Get transition lambdas.
        lam = np.zeros(n_trans)
        for i, Xi in enumerate(X):
            lam += trans_rate * (trans_source == i) * Xi

        # Get cumulative lambdas.
        lam_cum = np.cumsum(lam)

        # Stop if no transition can occur anymore.
        if lam_cum[lam_cum.size-1] <= 0:
            break

        # Normalize cumulative lambdas (safe now, the total is positive).
        lam_norm = lam_cum / lam_cum[lam_cum.size-1]

        # Get transition time.
        r0 = np.random.uniform(0,1)
        trans_dt = -np.log(r0) / lam_cum[lam_cum.size-1]

        # Update t.
        _t = _t + trans_dt

        # Stop if the transition time falls outside this time step.
        if _t > dt:
            break

        # Get transition index.
        r1 = np.random.uniform(0,1)
        trans_idx = int(np.sum(r1 > lam_norm))

        # Execute transition. Update source and target X.
        _namespaceX_changes[trans_source[trans_idx]] -= 1
        _namespaceX_changes[trans_target[trans_idx]] += 1

        X[trans_source[trans_idx]] -= 1
        X[trans_target[trans_idx]] += 1

    return _namespaceX_changes[X_idx_return]
    ''', namespace={'X_changes': X_changes,
                    'X_last_calculation': X_last_calculation})
@check_units(X0=1, X1=1, rate_01=Hz, rate_10=Hz, t=second, dt=second, X_idx_return=1, result=1)
def get_X_change(X0, X1, rate_01, rate_10, t, dt, X_idx_return):
    raise NotImplementedError('Use Cython')

Note that you could use the same mechanism to provide the transition rates if they are constant over a run. Is that the case for you? If not, there will be quite a number of parameters to hand over to the function on each call once you have more states...
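
As a rough illustration of where this could go (plain numpy outside of Brian2; all names in this sketch are made up): a rate matrix Q, with Q[i, j] the transition rate from state i to state j, would let a single Gillespie step handle an arbitrary number of states, and Q could live in the function's namespace instead of being passed element by element:

import numpy as np

def gillespie_step(X, Q, dt):
    # Advance the channel-state counts X by one time step of length dt,
    # sampling transitions from the rate matrix Q (Gillespie algorithm).
    X = X.copy()
    src, tgt = np.nonzero(Q)  # all possible transitions
    t = 0.0
    while True:
        lam = Q[src, tgt] * X[src]  # propensity of each transition
        lam_total = lam.sum()
        if lam_total <= 0:
            break
        t += -np.log(np.random.uniform(0, 1)) / lam_total
        if t > dt:
            break
        # Pick a transition with probability proportional to its propensity.
        k = np.searchsorted(np.cumsum(lam) / lam_total, np.random.uniform(0, 1))
        X[src[k]] -= 1
        X[tgt[k]] += 1
    return X

Q = np.array([[0., 100.], [100., 0.]])  # the 2-state example from above, in Hz
X = gillespie_step(np.array([100, 0]), Q, dt=1e-4)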

I did not yet look into making things fast (the current Cython code uses a lot of Python code), but I'm sure there's plenty of room for improvement.

A general comment: I would not describe X0 and X1 via differential equations; instead, use a run_regularly operation:

eqs = Equations('''
rate_01 = alpha : Hz
rate_10 = beta : Hz
X0 : integer
X1 : integer
''')

start_scope()
NM = NeuronGroup(1, eqs, method='euler')
NM.run_regularly('''X0 += get_X_change(X0, X1, rate_01, rate_10, t, dt, 0)
                    X1 += get_X_change(X0, X1, rate_01, rate_10, t, dt, 1)''')

mstimberg commented on Jul 04 '19

Thanks, this helps a lot!

Note that you could use the same mechanism to provide the transition rates if they are constant over a run. Is that the case for you? If not, there will be quite a number of parameters to hand over to the function on each call once you have more states...

The variables I named "rates" are constant over one time step, since I assume the gating variables to be almost constant within a time step. However, the effective rates, named "lam" for lambda, are not constant: they depend on the transitions that happen within the time step, because they depend on the number of channels in each state. But one could definitely update only the lambdas that change, as sketched below.
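
A sketch of what I mean (illustrative only, reusing the variable names from the code above):

def update_lambdas(lam, X, trans_rate, trans_source, s, u):
    # After a transition moved one channel from state s to state u, only the
    # propensities of transitions leaving s or u can have changed, so only
    # those entries of lam need to be recomputed.
    changed = (trans_source == s) | (trans_source == u)
    lam[changed] = trans_rate[changed] * X[trans_source[changed]]
    return lam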

I did not yet look into making things fast (the current Cython code uses a lot of Python code), but I'm sure there's plenty of room for improvement.

I agree that there seems to be a lot of potential for improving speed. I will do that once everything is working.

A general comment: I would not describe X0 and X1 via differential equations; instead, use a run_regularly operation:

Nice, much better!


Minor change I added (in case anybody else wants to use this code): X_last_calculation = np.array([-1]) should be X_last_calculation = np.array([-1]).astype(float). With an integer array, the assignment _namespaceX_last_calculation[0] = t truncates the time to an integer, so the caching check never matches within a time step.

jonathanoesterle commented on Jul 04 '19