cellrank
Add initial init/term assignment
Description
TODOs:
- [ ] update docs, get rid of TODOs in the code
- [ ] fix existing tests
- [ ] ensure that initial/terminal states cannot overlap
- [ ] tests
How has this been tested?
TODO.
Closes
closes #908
Let me know once you want me to review/test this.
I'm trying to run this on a GPCCA estimator initialized from a VelocityKernel. I computed 4 macrostates and ran `g.predict()`, which gives me:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Input In [9], in <cell line: 1>()
----> 1 g.predict()
2 g.plot_terminal_states(discrete=True, legend_loc='right', s=100)
File ~/Projects/cellrank/cellrank/estimators/terminal_states/_gpcca.py:274, in GPCCA.predict(self, method, which, n_cells, alpha, stability_threshold, n_states)
271 raise RuntimeError("Compute macrostates first as `.compute_macrostates()`.")
273 # fmt: off
--> 274 if len(self._macrostates.cat.categories) == 1:
275 logg.warning(f"Found only one macrostate. Making it the single {which} state")
276 self.set_states_from_macrostates(names=None, which=which, n_cells=n_cells, params=self._create_params())
AttributeError: 'StatesHolder' object has no attribute 'cat'
Note: this wasn't actually done with a VelocityKernel; I saved the transition matrix to file and then worked with a precomputed kernel because of https://github.com/theislab/cellrank/issues/933
Hi @michalk8, can I try again?
Yes, the above error should've been fixed, but most likely there are other bugs still.
Codecov Report
Merging #925 (1fc260f) into master (6150f02) will decrease coverage by 0.15%. The diff coverage is 73.27%.
:exclamation: Current head 1fc260f differs from pull request most recent head 256b886. Consider uploading reports for the commit 256b886 to get more accurate results
Additional details and impacted files
@@ Coverage Diff @@
## master #925 +/- ##
==========================================
- Coverage 82.16% 82.01% -0.15%
==========================================
Files 58 58
Lines 8763 8796 +33
Branches 1653 1665 +12
==========================================
+ Hits 7200 7214 +14
- Misses 1030 1039 +9
- Partials 533 543 +10
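As a quick sanity check, the headline figures follow from the table above, assuming Codecov's usual convention that coverage = hits / (hits + misses + partials), i.e., partially covered lines do not count as covered. The helper name `coverage_pct` is only for illustration:

```python
def coverage_pct(hits: int, misses: int, partials: int) -> float:
    """Percentage of tracked lines that are fully covered (partials count as not covered)."""
    tracked = hits + misses + partials
    return 100.0 * hits / tracked


# Numbers taken from the coverage diff above.
master = coverage_pct(hits=7200, misses=1030, partials=533)  # 8763 lines
pr = coverage_pct(hits=7214, misses=1039, partials=543)  # 8796 lines

print(f"master: {master:.2f}%")  # master: 82.16%
print(f"PR:     {pr:.2f}%")  # PR:     82.01%
print(f"delta:  {pr - master:+.2f}")  # delta:  -0.15
```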
| Impacted Files | Coverage Δ | |
|---|---|---|
| cellrank/external/kernels/_statot_kernel.py | 59.37% <0.00%> (ø) | |
| cellrank/kernels/utils/_pseudotime_scheme.py | 91.54% <ø> (ø) | |
| cellrank/pl/_circular_projection.py | 77.70% <ø> (ø) | |
| cellrank/estimators/_base_estimator.py | 79.67% <50.00%> (+0.10%) | :arrow_up: |
| cellrank/estimators/mixins/_lineage_drivers.py | 76.27% <50.00%> (ø) | |
| cellrank/estimators/terminal_states/_cflare.py | 67.41% <50.00%> (-1.13%) | :arrow_down: |
| ...ank/estimators/mixins/_absorption_probabilities.py | 75.79% <66.66%> (-1.46%) | :arrow_down: |
| ...timators/terminal_states/_term_states_estimator.py | 75.08% <72.54%> (-8.92%) | :arrow_down: |
| cellrank/estimators/terminal_states/_gpcca.py | 78.26% <74.76%> (-0.93%) | :arrow_down: |
| cellrank/_utils/_docs.py | 92.15% <100.00%> (+0.32%) | :arrow_up: |
| ... and 5 more | | |
Hi Mike, just going through your PR. This is beautiful work, as always. Of course, I do still have a few comments 🙃
My first comment concerns the two methods `g.plot_states` and `g.plot_macrostates` for a GPCCA estimator `g`. I'm just not sure we really need two different plotting functions here: essentially, they do the same thing, and the behavior could be entirely controlled via the `which` parameter. I think we could keep only the `g.plot_macrostates` method and use
- `which='all'` to show all macrostates, mimicking the current output of `g.plot_macrostates`
- `which='initial'` or `which='terminal'` to show initial or terminal states, mimicking the current output of `g.plot_states`
- `which=['state_1', 'state_2', ...]` to plot a custom list of macrostates
- `which='initial_and_terminal'` to jointly show initial and terminal states. In that case, we would have to change the legend from the current `['state_1', 'state_2', ...]` to `['initial: state_1', 'terminal: state_2', ...]`, i.e., add this label to each initial and terminal state. This is a nice way to compress information.
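To make the proposal concrete, here is a minimal sketch of how a single `which` argument could be resolved into the macrostates to plot. The function `resolve_which` and its arguments are hypothetical (not CellRank API); it assumes macrostates are stored as a categorical `pandas.Series` and that the initial/terminal assignments are available as lists of macrostate names:

```python
from typing import List, Sequence, Union

import pandas as pd


def resolve_which(
    which: Union[str, Sequence[str]],
    macrostates: pd.Series,
    initial: Sequence[str] = (),
    terminal: Sequence[str] = (),
) -> List[str]:
    """Hypothetical helper: map the proposed `which` values to the macrostate names to plot."""
    all_states = list(macrostates.cat.categories)
    if isinstance(which, str):
        if which == "all":
            return all_states
        if which == "initial":
            return list(initial)
        if which == "terminal":
            return list(terminal)
        if which == "initial_and_terminal":
            # Initial states first, then terminal states not already listed.
            return list(initial) + [s for s in terminal if s not in initial]
        raise ValueError(f"Unknown value `which={which!r}`.")
    # Custom list of macrostate names: validate against the known categories.
    unknown = set(which) - set(all_states)
    if unknown:
        raise ValueError(f"Unknown macrostates: {sorted(unknown)}.")
    return list(which)
```

With a single entry point like this, plotting-specific options (e.g. enlarging only the highlighted cells) would apply identically to every mode.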
One reason for my suggesting this is that the two methods also behave a bit differently at the moment. For example, calling `g.plot_macrostates(discrete=True, legend_loc='right', s=100)` gives me Fig. 1, while `g.plot_states(which='terminal', legend_loc='right', size=100)` gives me Fig. 2 below. Notice how the size parameter is set to the same value, yet in Fig. 1 it seems to only affect the highlighted (colored) cells, while in Fig. 2 it seems to affect all cells. I prefer the former behavior, i.e., only increasing the size of the cells we are highlighting, and I even think this should be done by default.
Lastly, I don't actually see the `which` parameter when I inspect the docstring of the `g.plot_states` method; see Fig. 3. I think it is important to expose this crucial parameter. On the other hand, we don't need to expose `time_key` or `mode`: this way of plotting makes sense for fate probabilities, but not really for macrostate memberships, so I think we could remove these two options.
Fig. 1
Fig. 2
Fig. 3

My next comment concerns the `g.predict` method used to classify the macrostates. While `g.predict(which='terminal')` works fine and gives me the expected output, `g.predict(which='initial')` does not work: it gives me the exact same states. In this PR, we don't want to change how we used to compute initial states via outliers in the coarse-grained stationary distribution; we only want to change how we interact with these functionalities. Fig. 1 shows the output I get from running
`g.predict(which='initial')`
`g.plot_states(which='initial', legend_loc='right', s=100)`
So clearly these are the terminal states, and not the initial states. Interestingly, here only the highlighted cells are colored, as in the macrostates plotting function; see my comment above.
The ways in which we identify initial and terminal states differ, and depending on the `which` parameter, we have to change the method defaults. I think doing this via a single method leads to confusion, so I would suggest having two methods:
- `g.predict_terminal_states`: pretty much with the current parameters and defaults of `g.predict(which='terminal')`
- `g.predict_initial_states`: similar to what we used to have in `g._compute_initial_states()`, which seems to have been removed in this PR.
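For reference, the idea behind the initial-state method can be sketched as follows. This is a simplified illustration, not the removed `g._compute_initial_states()` code: it assumes "outliers" means the macrostates with the lowest probability mass in the coarse-grained stationary distribution, and the helper names `stationary_distribution` and `initial_macrostates` are hypothetical:

```python
from typing import List, Sequence

import numpy as np


def stationary_distribution(P: np.ndarray) -> np.ndarray:
    """Left eigenvector of a row-stochastic matrix P for eigenvalue 1, normalized to sum to 1."""
    evals, evecs = np.linalg.eig(P.T)
    idx = np.argmin(np.abs(evals - 1.0))  # eigenvalue closest to 1
    pi = np.real(evecs[:, idx])
    return pi / pi.sum()  # also fixes the arbitrary sign of the eigenvector


def initial_macrostates(P_coarse: np.ndarray, names: Sequence[str], n_states: int = 1) -> List[str]:
    """Hypothetical sketch: pick the macrostates with the smallest coarse stationary mass."""
    pi = stationary_distribution(np.asarray(P_coarse, dtype=float))
    order = np.argsort(pi)  # ascending: least visited macrostates first
    return [names[i] for i in order[:n_states]]


# Toy coarse-grained transition matrix (rows sum to 1): macrostate 0 only flows out.
P_coarse = np.array([
    [0.5, 0.3, 0.2],
    [0.0, 0.9, 0.1],
    [0.0, 0.1, 0.9],
])
print(initial_macrostates(P_coarse, ["early", "alpha", "beta"]))  # ['early']
```

Terminal states sit at the other end (macrostates that retain most of their mass), which is why the two predictions need different defaults and fit naturally into two separate methods.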
Fig. 1


I would strongly vote for removing the `g.compute_states()` alias; it just adds confusion. While this method sounds very similar to `g.compute_macrostates`, it does something really different, and the name is just not very suggestive of what the method actually does. If you want to have an alias for `g.predict`, then we need to talk about a more informative name for this alias method, or for the two alias methods; see my related comment above.
Before classifying macrostates as initial or terminal, the estimator prints an empty field for terminal states, but not for initial. I would suggest making this consistent, either show for both initial and terminal or for neither.
Even once I predict initial states, this still only shows terminal states. I would add initial states to this representation - we should treat both equally.
I would also vote for writing macrostates to the underlying AnnData's `.obs` field, same as for initial and terminal states. I have a use case in which I want to visualize the distribution of marker genes over macrostates. I can manually write the macrostates to AnnData, but then I still have the wrong color assignment; it would be much more convenient to have this written automatically upon computation.
Currently, I write them using
`adata.obs['macrostates'] = g2.macrostates`
`adata.uns['macrostates_colors'] = g2.macrostates_memberships.colors`
and then I plot the marker distribution using `sc.pl.violin(adata, keys='Bicc1', groupby='macrostates', rotation=90)`. This is useful to identify the state that is highest/lowest in a certain marker gene, which might help to decide which state should be initial/terminal.
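A minimal sketch of what "writing macrostates to AnnData" could look like, following scanpy's convention that `.uns[f"{key}_colors"]` holds one color per category of `.obs[key]`, in category order. The helper `write_macrostates` is hypothetical, and it assumes the colors passed in are already in category order; a `SimpleNamespace` stands in for an AnnData object so the snippet runs without AnnData installed:

```python
from types import SimpleNamespace
from typing import Sequence

import pandas as pd


def write_macrostates(adata, macrostates: pd.Series, colors: Sequence[str], key: str = "macrostates") -> None:
    """Hypothetical helper: store per-cell macrostates and their colors in scanpy's layout.

    `adata` only needs `.obs` (a DataFrame indexed by cell) and `.uns` (a dict),
    as provided by AnnData. Scanpy pairs the i-th entry of `.uns[f"{key}_colors"]`
    with the i-th category of `.obs[key]`, so `colors` must follow category order.
    """
    macrostates = macrostates.astype("category")
    if len(colors) != len(macrostates.cat.categories):
        raise ValueError("Expected exactly one color per macrostate.")
    # Align on cell names; cells without a macrostate become NaN.
    adata.obs[key] = macrostates.reindex(adata.obs.index)
    adata.uns[f"{key}_colors"] = list(colors)


# Stand-in for an AnnData object, for illustration only.
adata = SimpleNamespace(obs=pd.DataFrame(index=["c1", "c2", "c3"]), uns={})
states = pd.Series(["A", "B", "A"], index=["c1", "c2", "c3"])
write_macrostates(adata, states, colors=["#1f77b4", "#ff7f0e"])
```

With a real AnnData object, scanpy plots grouped or colored by `macrostates` should then pick up the stored colors instead of generating a new palette.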
To get a feeling for how I plan to use this PR in the new initial & terminal states tutorial, please check this related PR in the notebooks repo: https://github.com/theislab/cellrank_notebooks/pull/56
Discussion outcome
- [x] remove alias `g.compute_states`
- [x] unify `g.plot_states` and `g.plot_macrostates`
- [x] introduce `g.predict_terminal_states` and `g.predict_initial_states`, where the latter uses the coarse-grained stationary distribution (check old code)
- [x] no overlap between initial and terminal states (warning and rename, or throw an error)
- [x] only print initial and terminal states (not macrostates)
- [x] writing macrostates to AnnData
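The "no overlap" item could be enforced with a check along these lines. This is a hypothetical sketch (the function `check_overlap` and its `allow_overlap` flag are illustrative names, not confirmed API); it covers the "throw an error" and "warning" variants, while the "rename" variant is not shown:

```python
import warnings
from typing import Iterable, Set


def check_overlap(initial: Iterable[str], terminal: Iterable[str], allow_overlap: bool = False) -> Set[str]:
    """Hypothetical check: a macrostate should not be both initial and terminal."""
    overlap = set(initial) & set(terminal)
    if overlap:
        msg = f"Macrostates {sorted(overlap)} are both initial and terminal."
        if not allow_overlap:
            raise ValueError(msg)
        warnings.warn(msg, stacklevel=2)
    return overlap


check_overlap(["Ngn3 low EP"], ["Alpha", "Beta"])  # no overlap, returns set()
```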
Hi Mike, happy new year! Did you have time to implement some of the changes already?
Thanks @michalk8, here are my comments:
- As suggested before, I would prefer to have a `g.plot_macrostates` rather than a `g.plot_states` method. I think this makes more sense, as everything we are dealing with are macrostates (initial macrostates, terminal macrostates, etc.). Is there a reason you chose to call the method `g.plot_states`? It's quite confusing now: we have a `g.compute_macrostates` method, but then a `g.plot_states` method. Also, internally, we refer to stuff as `macrostates`.
- Plotting does not work yet: calling `g.compute_macrostates(cluster_key='clusters', n_states=[4, 12], method='brandts')` followed by `g.plot_states()` gives Error 1 below.
- `g.plot_states` does not have a `which` parameter as we discussed above.
I think it makes more sense to have another meeting to discuss the anticipated workflow before we continue.
Error 1
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [19], line 1
----> 1 g.plot_states()
File ~/Projects/cellrank/cellrank/estimators/mixins/_utils.py:445, in register_plotter.<locals>.wrapper(wrapped, instance, args, kwargs)
442 else:
443 raise TypeError(f"Unable to plot a value of type `{type(obj)}`.")
--> 445 return wrapped(
446 *args,
447 _data=data,
448 _colors=colors,
449 _title=f"{attr} states",
450 discrete=discrete,
451 **kwargs,
452 )
File ~/miniforge3/envs/py39_arm/lib/python3.9/site-packages/scanpy/_utils/__init__.py:119, in deprecated_arg_names.<locals>.decorator.<locals>.func_wrapper(*args, **kwargs)
117 # reset filter
118 warnings.simplefilter('default', DeprecationWarning)
--> 119 return func(*args, **kwargs)
File ~/Projects/cellrank/cellrank/estimators/mixins/_utils.py:366, in _plot_dispatcher(self, states, color, discrete, mode, time_key, same_plot, title, cmap, **kwargs)
336 """
337 Plot continuous or categorical observations in an embedding or along pseudotime.
338
(...)
363 %(just_plots)s
364 """
365 if discrete:
--> 366 return _plot_discrete(
367 self,
368 states=states,
369 color=color,
370 same_plot=same_plot,
371 title=title,
372 cmap=cmap,
373 **kwargs,
374 )
376 return _plot_continuous(
377 self,
378 states=states,
(...)
385 **kwargs,
386 )
File ~/Projects/cellrank/cellrank/estimators/mixins/_utils.py:166, in _plot_discrete(self, _data, _colors, _title, states, color, title, same_plot, cmap, **kwargs)
153 def _plot_discrete(
154 self: PlotterProtocol,
155 _data: pd.Series,
(...)
163 **kwargs: Any,
164 ) -> None:
165 if not isinstance(_data, pd.Series):
--> 166 raise TypeError(
167 f"Expected `data` to be of type `pandas.Series`, found `{type(_data).__name__}`."
168 )
169 if not is_categorical_dtype(_data):
170 raise TypeError(
171 f"Expected `data` to be `categorical`, found `{infer_dtype(_data)}`."
172 )
TypeError: Expected `data` to be of type `pandas.Series`, found `NoneType`.
Updated to reflect all the changes mentioned above. However, I removed the default for `which`, because I think the user should always supply it.
Oh, exciting, thanks for the updates! I will check this out soon! Looking fwd. to this.
Thanks, Mike, here are a number of comments (we're almost there!):
- In `g.plot_macrostates`, please change `which="macro"` to `which="all"` to show all macrostates. I think this makes more sense syntactically: you're calling a method that has "macrostates" in its name, so I don't think we should have a parameter that reiterates this.
- Please remove the `g.predict()` alias for `g.predict_terminal_states()`. It is unintuitive why that alias points to `g.predict_terminal_states()` and not `g.predict_initial_states()`, as we want to treat them equally in CR2.
I'll post these now and add a few more later
There is a bit of an inconsistency now between automatic and manual identification of initial and terminal states:
- In automatic mode, we have two methods, `g.predict_terminal_states` and `g.predict_initial_states`.
- In manual mode, we only have one method, `g.set_states_from_macrostates`, and we control the behavior via the `which` parameter.
You could argue that the second mode of action is a bit similar to macrostate plotting. However, when we plot macrostates, the method is called `g.plot_macrostates`, which I much prefer over the old `g.plot_states`. In general, I would avoid the term `states` and restrict our vocabulary to `macrostates`, `terminal_states` and `initial_states`, to keep things simple.
So I suggest having `g.set_terminal_states_from_macrostates` and `g.set_initial_states_from_macrostates` for consistency. What is your opinion here, @michalk8?
> In `g.plot_macrostates`, please change `which="macro"` to `which="all"` to show all macrostates. I think this makes more sense syntactically - you're calling a method that has "macrostates" in its name, so I don't think we should have a parameter that re-iterates this.

Ok!

> Please remove the `g.predict()` alias for the

`predict` cannot be removed since it's an abstract method of the base estimator (alongside `fit`).

> So I suggest having `g.set_terminal_states_from_macrostates` and `g.set_initial_states_from_macrostates` for consistency. What is your opinion here @michalk8 ?

I think the proposed names are too long.
> Please remove the `g.predict()` alias for the
> `predict` cannot be removed since it's an abstract method of the base estimator (alongside `fit`).
That's an issue! It's very arbitrary what `predict` points to! We should talk about whether we really want our estimators to follow this fit/predict mechanism; I have the feeling this is a bit forced at the moment and might not actually work for us.
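The tension can be illustrated with a toy class hierarchy. This is a hypothetical sketch, not CellRank's actual classes: once `predict` is an abstract method of the base class, every concrete estimator must implement it, so a GPCCA-like estimator with two symmetric prediction methods has to choose one of them as the target of the alias:

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class BaseEstimator(ABC):
    """Hypothetical base class following the fit/predict pattern."""

    @abstractmethod
    def fit(self, *args, **kwargs) -> "BaseEstimator":
        ...

    @abstractmethod
    def predict(self, *args, **kwargs):
        ...


class ToyGPCCA(BaseEstimator):
    """Toy estimator with two symmetric prediction methods."""

    def fit(self, *args, **kwargs) -> "ToyGPCCA":
        self.initial_states: Optional[List[str]] = None
        self.terminal_states: Optional[List[str]] = None
        return self

    def predict_initial_states(self, names: Sequence[str] = ("initial_1",)) -> List[str]:
        self.initial_states = list(names)
        return self.initial_states

    def predict_terminal_states(self, names: Sequence[str] = ("terminal_1",)) -> List[str]:
        self.terminal_states = list(names)
        return self.terminal_states

    def predict(self, *args, **kwargs) -> List[str]:
        # Required by the abstract base class; the alias must pick one of the two
        # methods, here (as an assumption) the terminal states.
        return self.predict_terminal_states(*args, **kwargs)


print(ToyGPCCA().fit().predict())  # ['terminal_1']
```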
Note to self: the coarse-grained stationary distribution changes compared to the tutorial version on master (the kernels and estimators one). I think this has something to do with the way in which I define the VelocityKernel.
Update: I verified this. Running the current PR version of CellRank on the kernels and estimators tutorial gives the old coarse-grained stationary distribution, so these changes must be due to a slightly different setup of the VelocityKernel.
Previous version of the velocity kernel
12 macrostates
We correctly identify the Ngn3 low EP 1 state as initial:
15 macrostates
We correctly identify the Ngn3 low EP 1 state as initial:
Ok, in summary:
- We agreed on the first point about "macro" -> "all" (and I think this would be a good default value).
- We can further discuss the `predict` alias tomorrow, as well as the naming/mechanism for semi-manual macrostate finding.
- One more feature that would be fantastic to add to the `plot_macrostates` method is a mode `which="initial_and_terminal"`, which shows both initial and terminal states. To tell them apart, the names should be modified in the legend such that they show up as `initial: macrostate_1`, `terminal: macrostate_2`, etc. I suggested this above already; what do you think? I think this would be a very useful aggregate plotting mode.
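The legend relabeling for such a mode could be produced by merging the two per-cell annotations into one categorical with prefixed labels. This is a hypothetical sketch (`combine_initial_terminal` is not CellRank API); it assumes initial and terminal states are categorical `pandas.Series` over the same cells, with NaN for unassigned cells, and that no cell belongs to both:

```python
import pandas as pd


def combine_initial_terminal(initial: pd.Series, terminal: pd.Series) -> pd.Series:
    """Hypothetical sketch: one categorical series with 'initial: ' / 'terminal: ' labels."""
    init = initial.astype(object).map(lambda c: f"initial: {c}", na_action="ignore")
    term = terminal.astype(object).map(lambda c: f"terminal: {c}", na_action="ignore")
    # Take the initial label where present, otherwise fall back to the terminal label.
    combined = init.combine_first(term)
    # Keep a stable legend order: all initial states first, then all terminal states.
    categories = [f"initial: {c}" for c in initial.cat.categories] + [
        f"terminal: {c}" for c in terminal.cat.categories
    ]
    return pd.Series(pd.Categorical(combined, categories=categories), index=initial.index)


cells = ["c1", "c2", "c3", "c4"]
initial = pd.Series(pd.Categorical(["Ngn3 low EP", None, None, None]), index=cells)
terminal = pd.Series(pd.Categorical([None, "Alpha", "Beta", None]), index=cells)
labels = combine_initial_terminal(initial, terminal)
```

The colors for the combined plot could be concatenated in the same order as the categories.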
Discussion outcome:
- Let's keep the `predict` alias for `predict_terminal_states`.
- Rename `which="macro"` to `which="all"`.
- Unify `set_states_from_macrostates` with the `set_states` method into two methods, `set_initial_states` and `set_terminal_states`, which can either set initial/terminal states from macrostates or purely manually from a pandas Series or similar, depending on the input.
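The "depending on the input" behavior could be implemented as a small dispatch on the argument type. This is a hypothetical sketch of what `set_initial_states` / `set_terminal_states` might do internally (the helper `states_from_input` is illustrative, not the actual implementation): a `pandas.Series` is treated as a manual per-cell annotation, while a string or list of strings is interpreted as names of existing macrostates:

```python
from typing import Sequence, Union

import pandas as pd


def states_from_input(states: Union[str, Sequence[str], pd.Series], macrostates: pd.Series) -> pd.Series:
    """Hypothetical dispatch: manual annotation (Series) vs. selection of macrostates (names)."""
    if isinstance(states, pd.Series):
        # Purely manual: a per-cell annotation, used as-is (cast to categorical).
        return states.astype("category")
    names = [states] if isinstance(states, str) else list(states)
    unknown = set(names) - set(macrostates.cat.categories)
    if unknown:
        raise ValueError(f"Unknown macrostates: {sorted(unknown)}.")
    # Keep only cells assigned to the selected macrostates; all other cells become NaN.
    values = macrostates.astype(object).where(macrostates.isin(names))
    return pd.Series(pd.Categorical(values, categories=names), index=macrostates.index)


macro = pd.Series(pd.Categorical(["A", "B", "C", "A"]), index=list("wxyz"))
terminal = states_from_input(["A", "C"], macro)  # from macrostates
manual = states_from_input(pd.Series(["X", None, None, "X"], index=list("wxyz")), macro)  # manual
```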
@WeilerP, this PR introduces quite a few changes to the way we compute initial and terminal states (not the actual methods we use under the hood, only how we call the corresponding methods, etc.). Just to let you know that there will be some breaking changes here. I can guide you through this in more detail next week.
Uh, I'm excited to check this out @michalk8!