Add initial init/term assignment

Open michalk8 opened this issue 1 year ago • 5 comments

Description

TODOs:

  • [ ] update docs, get rid of TODOs in the code
  • [ ] fix existing tests
  • [ ] ensure that initial/terminal states cannot overlap
  • [ ] tests

How has this been tested?

TODO.

Closes

closes #908

michalk8 avatar Jul 10 '22 14:07 michalk8

Let me know once you want me to review/test this.

Marius1311 avatar Jul 11 '22 14:07 Marius1311

I'm trying to run this on a GPCCA estimator initialized from a VelocityKernel. I computed 4 macrostates and ran `g.predict()`, which gives me:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [9], in <cell line: 1>()
----> 1 g.predict()
      2 g.plot_terminal_states(discrete=True, legend_loc='right', s=100)

File ~/Projects/cellrank/cellrank/estimators/terminal_states/_gpcca.py:274, in GPCCA.predict(self, method, which, n_cells, alpha, stability_threshold, n_states)
    271     raise RuntimeError("Compute macrostates first as `.compute_macrostates()`.")
    273 # fmt: off
--> 274 if len(self._macrostates.cat.categories) == 1:
    275     logg.warning(f"Found only one macrostate. Making it the single {which} state")
    276     self.set_states_from_macrostates(names=None, which=which, n_cells=n_cells, params=self._create_params())

AttributeError: 'StatesHolder' object has no attribute 'cat'

Marius1311 avatar Jul 28 '22 08:07 Marius1311

Note: this wasn't actually done with a VelocityKernel, I saved the transition matrix to file and then worked with a precomputed kernel because of https://github.com/theislab/cellrank/issues/933

Marius1311 avatar Jul 28 '22 08:07 Marius1311

Hi @michalk8, can I try again?

Marius1311 avatar Aug 08 '22 10:08 Marius1311

Hi @michalk8, can I try again?

Yes, the above error should've been fixed, but most likely there are other bugs still.

michalk8 avatar Aug 11 '22 11:08 michalk8

Codecov Report

Merging #925 (1fc260f) into master (6150f02) will decrease coverage by 0.15%. The diff coverage is 73.27%.

:exclamation: Current head 1fc260f differs from pull request most recent head 256b886. Consider uploading reports for the commit 256b886 to get more accurate results

Additional details and impacted files

@@            Coverage Diff             @@
##           master     #925      +/-   ##
==========================================
- Coverage   82.16%   82.01%   -0.15%     
==========================================
  Files          58       58              
  Lines        8763     8796      +33     
  Branches     1653     1665      +12     
==========================================
+ Hits         7200     7214      +14     
- Misses       1030     1039       +9     
- Partials      533      543      +10     
| Impacted Files | Coverage Δ | |
|---|---|---|
| cellrank/external/kernels/_statot_kernel.py | 59.37% <0.00%> (ø) | |
| cellrank/kernels/utils/_pseudotime_scheme.py | 91.54% <ø> (ø) | |
| cellrank/pl/_circular_projection.py | 77.70% <ø> (ø) | |
| cellrank/estimators/_base_estimator.py | 79.67% <50.00%> (+0.10%) | :arrow_up: |
| cellrank/estimators/mixins/_lineage_drivers.py | 76.27% <50.00%> (ø) | |
| cellrank/estimators/terminal_states/_cflare.py | 67.41% <50.00%> (-1.13%) | :arrow_down: |
| ...ank/estimators/mixins/_absorption_probabilities.py | 75.79% <66.66%> (-1.46%) | :arrow_down: |
| ...timators/terminal_states/_term_states_estimator.py | 75.08% <72.54%> (-8.92%) | :arrow_down: |
| cellrank/estimators/terminal_states/_gpcca.py | 78.26% <74.76%> (-0.93%) | :arrow_down: |
| cellrank/_utils/_docs.py | 92.15% <100.00%> (+0.32%) | :arrow_up: |

... and 5 more

codecov[bot] avatar Dec 06 '22 20:12 codecov[bot]

Hi Mike, just going through your PR. This is beautiful work, as always. Of course, I do still have a few comments 🙃

My first comment concerns the two methods g.plot_states and g.plot_macrostates, for GPCCA estimator g. I'm just not sure we really need two different plotting functions here - essentially, they do the same thing, and the behavior could be entirely controlled via the which parameter. I think we could only keep the g.plot_macrostates method and use

  • which='all' to show all macrostates, mimicking the current output of g.plot_macrostates
  • which='initial/terminal' to show initial or terminal states, mimicking the current output of g.plot_states
  • which=['state_1', 'state_2', ...] to plot a custom list of macrostates
  • which='initial_and_terminal' to jointly show initial and terminal states. In that case, we would have to change the legend from the current ['state_1', 'state_2', ...] to ['initial: state_1', 'terminal: state_2', ...], i.e., prefix each initial and terminal state with its role. This is a nice way to compress information.

One reason for my suggesting this is that the two methods also behave a bit differently at the moment. For example, calling g.plot_macrostates(discrete=True, legend_loc='right', s=100) gives me Fig. 1, while g.plot_states(which='terminal', legend_loc='right', size=100) gives me Fig. 2 below. Notice how the size parameter is set to the same value, yet in Fig. 1, this seems to only have an effect on the highlighted cells in color, while in Fig. 2, it seems to affect all cells. I like the former behavior better, e.g., just increasing the size of the cells we are highlighting, and I even think this should be done by default.

Lastly, I don't actually see the which parameter when I inspect the docstring of the g.plot_states method; see Fig. 3. I think it is important to expose this crucial parameter. On the other hand, we don't need to expose time_key or mode; this way of plotting makes sense for fate probabilities, but not really for macrostate memberships, so I think we could remove these two options.

Fig. 1

image

Fig. 2

image

Fig. 3

Screenshot 2022-12-18 at 13 08 07

Marius1311 avatar Dec 18 '22 12:12 Marius1311
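
The `which` options proposed in the comment above can be sketched as a small selection helper. This is only an illustration under stated assumptions: `select_macrostates` and its arguments are hypothetical names, not part of CellRank, and the macrostates are assumed to be a categorical pandas Series with NaN for unassigned cells.

```python
from typing import Dict, Sequence, Union

import pandas as pd


def select_macrostates(
    macrostates: pd.Series,
    which: Union[str, Sequence[str]],
    initial: Sequence[str] = (),
    terminal: Sequence[str] = (),
) -> pd.Series:
    """Hypothetical helper: keep only the requested macrostates, set other cells to NaN."""
    categories = list(macrostates.cat.categories)

    # Map each macrostate name that should be shown to its legend label.
    if isinstance(which, str):
        if which == "all":
            labels: Dict[str, str] = {name: name for name in categories}
        elif which == "initial":
            labels = {name: name for name in initial}
        elif which == "terminal":
            labels = {name: name for name in terminal}
        elif which == "initial_and_terminal":
            # Prefix the role so both kinds can share one legend.
            labels = {name: f"initial: {name}" for name in initial}
            labels.update({name: f"terminal: {name}" for name in terminal})
        else:
            raise ValueError(f"Unknown value for `which`: {which!r}.")
    else:
        labels = {name: name for name in which}

    unknown = sorted(set(labels) - set(categories))
    if unknown:
        raise KeyError(f"Not valid macrostate names: {unknown}.")

    # Names missing from `labels` (and unassigned cells) map to NaN.
    return macrostates.astype(object).map(labels).astype("category")


macrostates = pd.Series(
    ["Alpha", "Beta", None, "Ngn3 low EP", "Alpha"], dtype="category"
)
shown = select_macrostates(
    macrostates,
    which="initial_and_terminal",
    initial=["Ngn3 low EP"],
    terminal=["Alpha", "Beta"],
)
print(list(shown.cat.categories))
```

A single mapping from macrostate name to legend label keeps all four modes in one code path; an explicit list of names is simply the identity mapping restricted to those names.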

My next comment concerns the g.predict method used to classify the macrostates. While g.predict(which='terminal') works fine and gives me the expected output, g.predict(which='initial') does not work because it gives me the exact same states. In this PR, we don't want to change how we used to compute initial states via outliers in the coarse-grained stationary distribution; we only want to change how we interact with these functionalities. Fig. 1 shows the output I get from running

g.predict(which='initial')
g.plot_states(which='initial', legend_loc='right', s=100)

So clearly these are the terminal states, and not the initial states. Interestingly, only the highlighted cells are colored as in the macrostates plotting function, see my comment above.

The ways in which we identify initial and terminal states differ, and depending on the which parameter, we have to change method defaults. I think this leads to confusion when we do it via a single method, and I would suggest having two methods:

  • g.predict_terminal_states: pretty much with the current parameters and defaults of g.predict(which='terminal')
  • g.predict_initial_states: similar to what we used to have in g._compute_initial_states(); which seems to have been removed in this PR.

Fig. 1

Screenshot 2022-12-18 at 13 16 27

Marius1311 avatar Dec 18 '22 12:12 Marius1311
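
To make the initial-state side of the proposed split concrete, here is a minimal sketch. It assumes that "outliers in the coarse-grained stationary distribution" means the macrostates carrying the least stationary mass; the thread does not spell out the rule, so treat this as an assumption. The function name mirrors the proposed method, but the standalone signature is hypothetical.

```python
import pandas as pd


def predict_initial_states(
    coarse_stationary_distribution: pd.Series, n_states: int = 1
) -> list:
    """Hypothetical sketch: pick the macrostates with the smallest stationary mass.

    Assumption: initial states are the least-visited macrostates under the
    coarse-grained stationary distribution (indexed by macrostate name).
    """
    if not 1 <= n_states <= len(coarse_stationary_distribution):
        raise ValueError(
            f"Expected `n_states` to be in [1, {len(coarse_stationary_distribution)}], "
            f"found `{n_states}`."
        )
    return list(coarse_stationary_distribution.nsmallest(n_states).index)


# Toy values for illustration only; they are not taken from any dataset.
csd = pd.Series({"Alpha": 0.40, "Beta": 0.35, "Ngn3 low EP": 0.05, "Delta": 0.20})
print(predict_initial_states(csd))
```

Because this rule differs from whatever is used for terminal states, it would need its own defaults (here `n_states=1`), which is the argument made above for two separate methods rather than one `which` switch.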

Before classifying macrostates as initial or terminal, the estimator prints an empty field for terminal states, but not for initial. I would suggest making this consistent, either show for both initial and terminal or for neither.

Screenshot 2022-12-18 at 14 44 12

Marius1311 avatar Dec 18 '22 13:12 Marius1311

I would strongly vote for removing the g.compute_states() alias; it just adds confusion. While this method sounds very similar to g.compute_macrostates, it does something really different.

The name is just not very suggestive for what this method actually does. If you want to have an alias for g.predict, then we need to talk about a more informative name for this alias method - or for the two alias methods, see my related comment above.

Marius1311 avatar Dec 18 '22 13:12 Marius1311

Before classifying macrostates as initial or terminal, the estimator prints an empty field for terminal states, but not for initial. I would suggest making this consistent, either show for both initial and terminal or for neither.

Screenshot 2022-12-18 at 14 44 12

Even once I predict initial states, this still only shows terminal states. I would add initial states to this representation - we should treat both equally.

Marius1311 avatar Dec 18 '22 14:12 Marius1311

I would also vote for writing macrostates to the underlying AnnData's .obs field, same as for initial and terminal states. I have a use case in which I want to visualize the distribution of marker genes over macrostates. I can manually write the macrostates to AnnData, but then I still have the wrong color assignment; it would be much more convenient to have this written automatically upon computation.

Currently, I write using

adata.obs['macrostates'] = g2.macrostates
adata.uns['macrostates_colors'] = g2.macrostates_memberships.colors

and then I plot marker distribution using sc.pl.violin(adata, keys='Bicc1', groupby='macrostates', rotation=90). This is useful to identify a particular state that is highest/lowest in a certain marker gene, which might be helpful to decide which state should be initial/terminal.

Marius1311 avatar Dec 18 '22 14:12 Marius1311
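
The manual write shown in the comment above follows the scanpy convention that colors for a categorical `.obs` column live in `.uns` under `<key>_colors`, ordered like the categories. A minimal sketch of doing both writes together is given here; to stay self-contained it uses a plain `pandas.DataFrame` and a `dict` as stand-ins for `adata.obs` and `adata.uns`, and `write_macrostates` is a hypothetical helper name.

```python
from typing import Dict, Sequence

import pandas as pd


def write_macrostates(
    obs: pd.DataFrame,
    uns: Dict,
    macrostates: pd.Series,
    colors: Sequence[str],
    key: str = "macrostates",
) -> None:
    """Hypothetical helper: store categorical macrostates together with their colors."""
    if not isinstance(macrostates.dtype, pd.CategoricalDtype):
        raise TypeError("Expected `macrostates` to be categorical.")
    if len(colors) != len(macrostates.cat.categories):
        raise ValueError("Expected exactly one color per macrostate category.")

    # Align to the observation index; cells without a macrostate stay NaN.
    obs[key] = macrostates.reindex(obs.index)
    # One color per category, in category order.
    uns[f"{key}_colors"] = list(colors)


obs = pd.DataFrame(index=["cell_1", "cell_2", "cell_3"])  # stand-in for adata.obs
uns: Dict = {}  # stand-in for adata.uns
states = pd.Series(["Alpha", None, "Beta"], index=obs.index, dtype="category")
write_macrostates(obs, uns, states, colors=["#1f77b4", "#ff7f0e"])
print(obs["macrostates"].cat.categories.tolist(), uns["macrostates_colors"])
```

Writing the labels and the colors in the same step is what avoids the mismatched color assignment described above.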

To get a feeling for how I plan to use this PR in the new initial & terminal states tutorial, please check this related PR in the notebooks repo: https://github.com/theislab/cellrank_notebooks/pull/56

Marius1311 avatar Dec 18 '22 15:12 Marius1311

Discussion outcome

  • [x] remove alias g.compute_states
  • [x] unify g.plot_states and g.plot_macrostates
  • [x] introduce g.predict_terminal_states and g.predict_initial_states, where the latter uses the coarse-grained stationary distribution (check old code)
  • [x] no overlap between initial and terminal states (Warning and rename, or throw an error)
  • [x] only print initial and terminal states (not macrostates)
  • [x] writing macrostates to AnnData

Marius1311 avatar Dec 19 '22 15:12 Marius1311
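
The "no overlap between initial and terminal states" item can be illustrated with a small check. This is a sketch, not the code in the PR: `check_state_overlap` and its `on_overlap` policy are hypothetical, and only the "warn" and "raise" options are shown because the thread does not specify how a rename would work.

```python
import warnings
from typing import Iterable, List


def check_state_overlap(
    initial: Iterable[str], terminal: Iterable[str], on_overlap: str = "raise"
) -> List[str]:
    """Hypothetical check: return the overlapping names, warning or raising if any exist."""
    if on_overlap not in ("raise", "warn"):
        raise ValueError(f"Unknown policy: {on_overlap!r}.")

    overlap = sorted(set(initial) & set(terminal))
    if overlap:
        msg = f"States {overlap} are marked as both initial and terminal."
        if on_overlap == "raise":
            raise ValueError(msg)
        warnings.warn(msg, stacklevel=2)
    return overlap


print(check_state_overlap(["Ngn3 low EP"], ["Alpha", "Beta"]))
```

Running the check in both setters (initial and terminal) would catch an overlap regardless of which kind of state is assigned first.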

Hi Mike, happy new year! Did you have time to implement some of the changes already?

Marius1311 avatar Jan 02 '23 13:01 Marius1311

Thanks @michalk8, here are my comments:

  • As suggested before, I would prefer to have a g.plot_macrostates method rather than a g.plot_states method. I think this makes more sense, as everything we are dealing with is a macrostate (initial macrostates, terminal macrostates, etc.). Is there a reason you chose to call the method g.plot_states? It's quite confusing now: we have a g.compute_macrostates method, but then a g.plot_states method. Also, internally, we refer to these things as macrostates.
  • Plotting does not work yet: calling g.compute_macrostates(cluster_key='clusters', n_states=[4, 12], method='brandts') followed by g.plot_states() gives Error 1 below
  • g.plot_states does not have a which parameter as we discussed above

I think it makes more sense to have another meeting to discuss the anticipated workflow before we continue.

Error 1

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [19], line 1
----> 1 g.plot_states()

File ~/Projects/cellrank/cellrank/estimators/mixins/_utils.py:445, in register_plotter.<locals>.wrapper(wrapped, instance, args, kwargs)
    442 else:
    443     raise TypeError(f"Unable to plot a value of type `{type(obj)}`.")
--> 445 return wrapped(
    446     *args,
    447     _data=data,
    448     _colors=colors,
    449     _title=f"{attr} states",
    450     discrete=discrete,
    451     **kwargs,
    452 )

File ~/miniforge3/envs/py39_arm/lib/python3.9/site-packages/scanpy/_utils/__init__.py:119, in deprecated_arg_names.<locals>.decorator.<locals>.func_wrapper(*args, **kwargs)
    117 # reset filter
    118 warnings.simplefilter('default', DeprecationWarning)
--> 119 return func(*args, **kwargs)

File ~/Projects/cellrank/cellrank/estimators/mixins/_utils.py:366, in _plot_dispatcher(self, states, color, discrete, mode, time_key, same_plot, title, cmap, **kwargs)
    336 """
    337 Plot continuous or categorical observations in an embedding or along pseudotime.
    338 
   (...)
    363 %(just_plots)s
    364 """
    365 if discrete:
--> 366     return _plot_discrete(
    367         self,
    368         states=states,
    369         color=color,
    370         same_plot=same_plot,
    371         title=title,
    372         cmap=cmap,
    373         **kwargs,
    374     )
    376 return _plot_continuous(
    377     self,
    378     states=states,
   (...)
    385     **kwargs,
    386 )

File ~/Projects/cellrank/cellrank/estimators/mixins/_utils.py:166, in _plot_discrete(self, _data, _colors, _title, states, color, title, same_plot, cmap, **kwargs)
    153 def _plot_discrete(
    154     self: PlotterProtocol,
    155     _data: pd.Series,
   (...)
    163     **kwargs: Any,
    164 ) -> None:
    165     if not isinstance(_data, pd.Series):
--> 166         raise TypeError(
    167             f"Expected `data` to be of type `pandas.Series`, found `{type(_data).__name__}`."
    168         )
    169     if not is_categorical_dtype(_data):
    170         raise TypeError(
    171             f"Expected `data` to be `categorical`, found `{infer_dtype(_data)}`."
    172         )

TypeError: Expected `data` to be of type `pandas.Series`, found `NoneType`.

Marius1311 avatar Feb 13 '23 12:02 Marius1311

Updated to reflect all the changes mentioned above. However, I removed the default for which, because I think the user should always supply it.

michalk8 avatar Feb 21 '23 22:02 michalk8

Oh, exciting, thanks for the updates! I will check this out soon! Looking fwd. to this.

Marius1311 avatar Feb 22 '23 09:02 Marius1311

Thanks, Mike, here are a number of comments (we're almost there!)

  • in g.plot_macrostates, please change which="macro" to which="all" to show all macrostates. I think this makes more sense syntactically - you're calling a method that has "macrostates" in its name, so I don't think we should have a parameter that re-iterates this.
  • Please remove the g.predict() alias for the g.predict_terminal_states(). It is unintuitive why that alias points to g.predict_terminal_states() and not g.predict_initial_states() as we want to treat them equally in CR2

I'll post these now and add a few more later

Marius1311 avatar Feb 23 '23 11:02 Marius1311

There is a bit of an inconsistency now between automatic and manual identification of initial and terminal states:

  • In automatic mode, we have two methods, g.predict_terminal_states and g.predict_initial_states
  • In manual mode, we only have one method, g.set_states_from_macrostates, and we control the behavior via the which parameter

You could argue that the second mode of action is a bit similar to macrostate plotting. However, when we plot macrostates, the method is called g.plot_macrostates, which I much prefer over the old g.plot_states. In general, I would avoid the term states and restrict our vocabulary to macrostates, terminal_states and initial_states, to keep things simple.

So I suggest having g.set_terminal_states_from_macrostates and g.set_initial_states_from_macrostates for consistency. What is your opinion here @michalk8 ?

Marius1311 avatar Feb 23 '23 12:02 Marius1311

in g.plot_macrostates, please change which="macro" to which="all" to show all macrostates. I think this makes more sense syntactically - you're calling a method that has "macrostates" in its name, so I don't think we should have a parameter that re-iterates this.

Ok!

Please remove the g.predict() alias for the

predict cannot be removed since it's an abstract method of the base estimator (alongside fit).

So I suggest having g.set_terminal_states_from_macrostates and g.set_initial_states_from_macrostates for consistency. What is your opinion here @michalk8 ?

I think the proposed names are too long.

michalk8 avatar Feb 23 '23 13:02 michalk8

in g.plot_macrostates, please change which="macro" to which="all" to show all macrostates. I think this makes more sense syntactically - you're calling a method that has "macrostates" in its name, so I don't think we should have a parameter that re-iterates this.

Ok!

Please remove the g.predict() alias for the

predict cannot be removed since it's an abstract method of the base estimator (alongside fit).

That's an issue! It's very arbitrary what predict points to! We should talk about whether we really want our estimators to follow this fit-predict mechanism. I have the feeling this is a bit forced at the moment and might not actually work for us.

So I suggest having g.set_terminal_states_from_macrostates and g.set_initial_states_from_macrostates for consistency. What is your opinion here @michalk8 ?

I think the proposed names are too long.

Marius1311 avatar Feb 23 '23 13:02 Marius1311
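
The constraint discussed above, that `predict` is an abstract method of the base estimator alongside `fit`, can be sketched with Python's `abc` module. The classes here are toy stand-ins, not CellRank's actual estimators, and the state-selection logic inside them is placeholder code; only the fit/predict contract and the alias are the point.

```python
from abc import ABC, abstractmethod
from typing import List


class ToyBaseEstimator(ABC):
    """Toy base class: every concrete estimator must implement `fit` and `predict`."""

    @abstractmethod
    def fit(self, *args, **kwargs) -> "ToyBaseEstimator":
        ...

    @abstractmethod
    def predict(self, *args, **kwargs) -> "ToyBaseEstimator":
        ...


class ToyGPCCA(ToyBaseEstimator):
    def __init__(self) -> None:
        self.macrostates: List[str] = []
        self.terminal_states: List[str] = []
        self.initial_states: List[str] = []

    def fit(self, n_states: int) -> "ToyGPCCA":
        # Placeholder: real macrostates would come from the transition matrix.
        self.macrostates = [f"state_{i}" for i in range(n_states)]
        return self

    def predict_terminal_states(self, n_states: int = 1) -> "ToyGPCCA":
        # Placeholder selection rule, for illustration only.
        self.terminal_states = self.macrostates[-n_states:]
        return self

    def predict_initial_states(self, n_states: int = 1) -> "ToyGPCCA":
        # Placeholder selection rule, for illustration only.
        self.initial_states = self.macrostates[:n_states]
        return self

    def predict(self, *args, **kwargs) -> "ToyGPCCA":
        # The abstract method has to exist, so it forwards to one of the two.
        return self.predict_terminal_states(*args, **kwargs)


g = ToyGPCCA().fit(n_states=3).predict()
print(g.terminal_states)
```

As long as the base class declares `predict` abstract, dropping it from the subclass makes the subclass impossible to instantiate, so the choice is between keeping an alias and changing the base-class contract.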

Note to self: the coarse-grained stationary distribution changes compared to the tutorial version on master (the kernels and estimators one). I think this has something to do with the way in which I define the VelocityKernel.

Update: I verified this, running the current PR version of CellRank on the kernels and estimators tutorial gives the old coarse stationary distribution. So these changes must be due to a slightly different setup of the Velocity kernel.

Previous version of the velocity kernel

12 macrostates

We correctly identify the Ngn3 low EP 1 state as initial:

Screenshot 2023-02-23 at 14 45 13

15 macrostates

We correctly identify the Ngn3 low EP 1 state as initial:

Screenshot 2023-02-23 at 14 46 18

Marius1311 avatar Feb 23 '23 14:02 Marius1311

Ok, in summary:

  • We agreed on the first point about "macro" -> "all" (and I think this would be a good default value)
  • We can further discuss the predict alias tomorrow, as well as the naming/mechanism for semi-manual macrostate finding
  • One more feature that would be fantastic to add to the plot_macrostates method is a mode which="initial_and_terminal", which shows both initial and terminal states. To tell them apart, the names should be modified in the legend such that they show up as initial: macrostate_1, terminal: macrostate_2 etc. I suggested this above already, what do you think? I think this would be a very useful aggregate plotting mode.

Marius1311 avatar Feb 23 '23 14:02 Marius1311

Discussion outcome:

  • Let's keep the predict alias for predict_terminal_states
  • Rename which="macro" to which="all"
  • Unify set_states_from_macrostates with the set_states method into two methods: set_initial_states and set_terminal_states, which can either set terminal/initial states from macrostates or purely manually from pandas Series or similar, depending on the input.

Marius1311 avatar Feb 24 '23 09:02 Marius1311
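
The agreed unification, one setter that accepts either macrostate names or a manual assignment, could dispatch on the input type roughly as follows. This is a sketch under assumptions: `resolve_states` is a hypothetical helper, and macrostates are assumed to be a categorical pandas Series indexed by cell.

```python
from typing import Sequence, Union

import pandas as pd


def resolve_states(
    states: Union[str, Sequence[str], pd.Series], macrostates: pd.Series
) -> pd.Series:
    """Hypothetical helper: turn macrostate names or a manual Series into per-cell labels."""
    if isinstance(states, pd.Series):
        # Purely manual assignment: must cover the same cells as the macrostates.
        if not states.index.equals(macrostates.index):
            raise ValueError("Expected `states` to share the index of `macrostates`.")
        return states.astype("category")

    names = [states] if isinstance(states, str) else list(states)
    unknown = sorted(set(names) - set(macrostates.cat.categories))
    if unknown:
        raise KeyError(f"Not valid macrostate names: {unknown}.")

    # Keep cells belonging to the selected macrostates, set the rest to NaN.
    selected = macrostates.astype(object).where(macrostates.isin(names))
    return selected.astype("category")


cells = ["c1", "c2", "c3", "c4"]
macrostates = pd.Series(
    ["Alpha", "Beta", "Alpha", "Epsilon"], index=cells, dtype="category"
)
from_names = resolve_states(["Alpha"], macrostates)
manual = resolve_states(pd.Series(["X", None, None, "X"], index=cells), macrostates)
print(list(from_names.cat.categories), list(manual.cat.categories))
```

Both `set_initial_states` and `set_terminal_states` could then share this resolution step and differ only in where the result is stored.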

@WeilerP, this PR introduces quite a few changes to the way we compute initial and terminal states - not the actual methods we use under the hood, only how we call the corresponding methods etc. Just to let you know that there will be some breaking changes here. I can guide you through this in more detail next week.

Marius1311 avatar Feb 24 '23 09:02 Marius1311

Uh, I'm excited to check this out @michalk8!

Marius1311 avatar Mar 21 '23 11:03 Marius1311