pertpy
Streamlining Distance API
Description of feature
This is a continuation of https://github.com/theislab/pertpy/issues/405, but specific to Distance.
TL;DR: Distance currently does not adhere to the API design of the rest of pertpy, and I want to harmonize it. At the moment, we pass a metric to the constructor, which then uses the appropriate distance function in __call__. This comes with two issues:
- It's not consistent with the rest of the API
- In the documentation (https://pertpy.readthedocs.io/en/latest/usage/tools/pertpy.tools.Distance.html#pertpy.tools.Distance), the options show up as one very long docstring list; if we wanted to document more, it would become unreadable and hard to navigate.
We currently also have the following functions:
- onesided_distances
- pairwise
- precompute_distances
Moving metric into these three functions wouldn't really help or solve any issue. The only option I see is having functions like

distance.compute_wasserstein(mode: Literal['onesided', 'pairwise', 'precompute'])

for all of the metrics. These would then show up in a table of functions and could be documented more easily. It would also probably fit the current design better.
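A minimal sketch of the proposed shape, assuming each metric becomes a method that takes a mode argument and delegates the mode handling to one shared helper. DistanceSketch, _dispatch, and the dict-of-arrays input (group name mapped to a cells × features matrix) are hypothetical stand-ins for pertpy's AnnData-based API; the metric is a simplified Euclidean distance between group means, and the 'precompute' mode is left out.

```python
from __future__ import annotations

from typing import Callable, Literal, Optional

import numpy as np

Mode = Literal["onesided", "pairwise"]


class DistanceSketch:
    """Hypothetical per-metric API: each metric is a method, `mode` picks the layout."""

    def euclidean(self, groups: dict[str, np.ndarray], mode: Mode = "pairwise",
                  selected_group: Optional[str] = None) -> dict:
        """Euclidean distance between group means (a simplified stand-in metric)."""
        return self._dispatch(self._euclidean_1v1, groups, mode, selected_group)

    @staticmethod
    def _euclidean_1v1(X: np.ndarray, Y: np.ndarray) -> float:
        # 1 vs 1: distance between the mean profiles of two groups
        return float(np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0)))

    @staticmethod
    def _dispatch(fn: Callable[[np.ndarray, np.ndarray], float], groups, mode, selected_group) -> dict:
        # shared by every metric, so the mode handling is written only once
        if mode == "onesided":
            if selected_group is None:
                raise ValueError("onesided mode needs selected_group")
            ref = groups[selected_group]
            return {name: fn(ref, Y) for name, Y in groups.items()}
        if mode == "pairwise":
            return {(a, b): fn(groups[a], groups[b]) for a in groups for b in groups}
        raise ValueError(f"unknown mode: {mode!r}")
```

With the mode logic in one helper, adding a metric means writing only its 1 vs 1 function plus a documented wrapper method.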
What do you think? I'm especially interested in the opinions of @yugeji, @stefanpeidli, and @tessadgreen.
Seems reasonable. No immediate problems come to mind. And by having one function per metric, we could add more documentation, including formulas, which I agree would be nice.
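Two issues come to mind:
- The obvious one: calling a set of distances one after another would look uglier. For example, right now I can call
```python
import pertpy as pt
metrics = ["edistance", "wasserstein"]  # example metric names
for metric in metrics:
    distance = pt.tl.Distance(metric=metric)
```
but with the proposed change, I would call
```python
for metric in metrics:
    # func_dict: a mapping from metric name to its per-metric function
    distance = func_dict[metric](mode='onesided')
```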
- Where does from_precomputed go? Currently, with the same distance object as above, you would call precompute_distances on an AnnData object, and then the distance's __call__, .onesided(X, Y), or .pairwise(X, Y) would all make use of the precomputed distances:
```python
import pertpy as pt

distance = pt.tl.Distance(metric='wasserstein')
distance.precompute_distances(adata)
df = distance(adata, groupby, ...)  # remaining arguments elided
```
In the proposed implementation, you would write:
```python
# proposed API, not implemented
distance = pt.tl.Distance.compute_wasserstein(mode='precompute')
distance(adata)
distance = pt.tl.Distance.compute_wasserstein(mode='pairwise')  # pairwise as an example; also where it makes the most sense
df = distance(adata, groupby, ...)
```
In my opinion, this is considerably less readable and less intuitive. It also applies not just to precompute but to any case in which you want to compute a summary statistic beforehand, which we definitely want to support because it gives a major speedup.
- This only matters if you implement it this way, but it should be named just wasserstein, not compute_wasserstein.
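The precomputation concern can be illustrated with a minimal sketch, assuming precomputed state is simply stored on the instance: precompute_distances computes a cell-by-cell distance matrix once, and each metric call only indexes into it. CachedDistanceSketch, the label-array interface, and this energy-distance variant (within-group means include the zero diagonal) are hypothetical and not pertpy's implementation. It shows one possible design in which per-metric methods and a reusable cache live on a single object, so precompute does not have to be a separate mode.

```python
from typing import Optional

import numpy as np


class CachedDistanceSketch:
    """Hypothetical sketch: compute cell-cell distances once, reuse them per metric call."""

    def __init__(self) -> None:
        self._dists: Optional[np.ndarray] = None

    def precompute_distances(self, X: np.ndarray) -> None:
        # full Euclidean distance matrix between all cells (rows of X), computed once
        diff = X[:, None, :] - X[None, :, :]
        self._dists = np.sqrt((diff ** 2).sum(axis=-1))

    def energy_distance(self, labels: np.ndarray, a: str, b: str) -> float:
        # 2 * mean cross distance - mean within-a - mean within-b, read from the cache
        if self._dists is None:
            raise RuntimeError("call precompute_distances first")
        ia, ib = labels == a, labels == b
        cross = self._dists[np.ix_(ia, ib)].mean()
        within_a = self._dists[np.ix_(ia, ia)].mean()
        within_b = self._dists[np.ix_(ib, ib)].mean()
        return float(2 * cross - within_a - within_b)
```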
And just to clarify, you're thinking of using it like
```python
distance = pt.tl.Distance.compute_wasserstein(mode='pairwise')
df = distance(adata, groupby, ...)
```
NOT
```python
distance = pt.tl.Distance()
df = distance.compute_wasserstein(mode='pairwise')(adata, groupby, ...)
```
Right?
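Whichever call style is chosen, the looping concern raised above could be handled with getattr, assuming each metric method is named exactly after its metric string (which also favors wasserstein over compute_wasserstein). DistanceSketch and its two methods are hypothetical, and the mode argument is omitted for brevity.

```python
import numpy as np


class DistanceSketch:
    """Hypothetical per-metric API where method names double as metric names."""

    def euclidean(self, X, Y):
        # Euclidean distance between the two group means
        return float(np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0)))

    def mean_absolute_error(self, X, Y):
        # mean absolute difference between the two group means
        return float(np.abs(X.mean(axis=0) - Y.mean(axis=0)).mean())


X = np.array([[0.0, 0.0], [2.0, 0.0]])
Y = np.array([[1.0, 3.0], [1.0, 5.0]])
distance = DistanceSketch()
results = {}
for metric in ["euclidean", "mean_absolute_error"]:
    # look the method up by name, so a loop over metric strings still works
    results[metric] = getattr(distance, metric)(X, Y)
```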
Discussed a few things with @yugeji:
- I need to find a way to make the looping easier.
- Some distances, like MMD, could make use of metric-specific parameters. Currently, these can only be passed as kwargs; the new design would let us document them properly.
- There are three modes we need to support: 1 vs 1, pairwise, and one-sided. While pairwise and one-sided would work as suggested above, we'd need to give 1 vs 1 first-class AnnData support. Currently, these implementations live in __call__ and only accept NumPy arrays.
- There will be a lot of repetition across docstrings, but we'll work around that with a docstring decorator.
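A docstring decorator of the kind mentioned could look like the following minimal sketch. doc_params, SHARED_PARAMS, the {shared_params} placeholder, and the kernel parameter for MMD are all hypothetical; the decorator assumes Google-style docstrings with the placeholder at a 4-space indent under Args:.

```python
import inspect
import textwrap

SHARED_PARAMS = (
    'mode: Either "onesided" or "pairwise".\n'
    'selected_group: Reference group; only used when mode is "onesided".'
)


def doc_params(func):
    """Hypothetical decorator: fill the {shared_params} placeholder of a docstring."""
    # cleandoc strips the docstring's common indentation, leaving the placeholder at 4 spaces
    doc = inspect.cleandoc(func.__doc__ or "")
    # indent every shared line to 4 spaces; the placeholder line already has its own indent
    shared = textwrap.indent(SHARED_PARAMS, "    ").lstrip()
    func.__doc__ = doc.replace("{shared_params}", shared)
    return func


@doc_params
def wasserstein(mode="pairwise", selected_group=None):
    """Wasserstein distance between groups (hypothetical per-metric function).

    Args:
        {shared_params}
    """


@doc_params
def mmd(mode="pairwise", selected_group=None, kernel="linear"):
    """Maximum mean discrepancy between groups (hypothetical per-metric function).

    Args:
        {shared_params}
        kernel: Kernel used by MMD; a metric-specific parameter.
    """
```

Using str.replace rather than str.format avoids clashes with literal braces elsewhere in a docstring, for example in LaTeX formulas.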
For future distances that do not use what is currently the standard __call__ format of (X, Y), implementing them in the new way would let you override onesided with a distance-specific implementation (for example, for classifier class projection or a KNN distance).
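A minimal sketch of that override, assuming a base class whose default onesided loops a 1 vs 1 (X, Y) distance over all groups. All class names are hypothetical; NearestCentroidFraction is only a toy stand-in for a metric that needs every group at once, loosely in the spirit of classifier class projection, and is not pertpy's classifier_cp.

```python
from __future__ import annotations

import numpy as np


class BaseDistanceSketch:
    """Hypothetical base: onesided defaults to looping a 1 vs 1 (X, Y) distance."""

    def one_vs_one(self, X: np.ndarray, Y: np.ndarray) -> float:
        raise NotImplementedError

    def onesided(self, groups: dict[str, np.ndarray], selected_group: str) -> dict[str, float]:
        ref = groups[selected_group]
        return {name: self.one_vs_one(ref, Y) for name, Y in groups.items()}


class MeanEuclidean(BaseDistanceSketch):
    # fits the (X, Y) format, so it only implements one_vs_one
    def one_vs_one(self, X, Y):
        return float(np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0)))


class NearestCentroidFraction(BaseDistanceSketch):
    """Toy metric that needs all groups at once, so it overrides onesided instead."""

    def onesided(self, groups, selected_group):
        names = list(groups)
        centroids = np.stack([groups[n].mean(axis=0) for n in names])
        ref = groups[selected_group]
        # assign every cell of the selected group to its closest group centroid
        sq_dists = ((ref[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        closest = np.argmin(sq_dists, axis=1)
        # fraction of the selected group's cells assigned to each group
        return {n: float(np.mean(closest == i)) for i, n in enumerate(names)}
```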
An important addition to this refactor (which would also allow classifier_cp to be used with pairwise) would be to make pairwise call onesided instead of relying on the copy-pasted code used right now (which is also causing problems).
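A minimal sketch of pairwise built on onesided, assuming dict-of-arrays input and a simplified mean-Euclidean metric (both hypothetical, as is DistanceSketch). Since pairwise only goes through onesided, any metric that overrides onesided, as classifier_cp would, gets a working pairwise without duplicated code.

```python
from __future__ import annotations

import numpy as np


class DistanceSketch:
    """Hypothetical sketch: pairwise is assembled from onesided, so there is one code path."""

    def onesided(self, groups: dict[str, np.ndarray], selected_group: str) -> dict[str, float]:
        # Euclidean distance between group means, from the selected group to every group
        ref = groups[selected_group].mean(axis=0)
        return {name: float(np.linalg.norm(ref - Y.mean(axis=0))) for name, Y in groups.items()}

    def pairwise(self, groups: dict[str, np.ndarray]) -> dict[tuple[str, str], float]:
        # one onesided call per group; a subclass overriding onesided changes pairwise too
        result = {}
        for a in groups:
            for b, value in self.onesided(groups, a).items():
                result[(a, b)] = value
        return result
```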