scanpy
Subsample by observations grouping
- [X] Additional function parameters / changed functionality / changed defaults?
- [ ] New analysis tool: A simple analysis tool you have been using and are missing in `sc.tools`?
- [ ] New plotting function: A kind of plot you would like to see in `sc.pl`?
- [ ] External tools: Do you know an existing package that should go into `sc.external.*`?
- [ ] Other?
Related to scanpy.pp.subsample, it would be useful to have a subsampling tool that subsamples based on the key of an observations grouping. E.g., if I have an observation key 'MyGroup' with possible values ['A', 'B'], and there are 10,000 cells of type 'A' and 2,000 cells of type 'B' and I want only max 5,000 cells of each type, then this function would subsample 5,000 cells of type 'A' but retain all 2,000 cells of type 'B'.
Something like this should work. Note, this is not tested.
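target_cells = 5000
# One AnnData view per category of the grouping column; isin() needs a list-like, hence [clust]
adatas = [adata[adata.obs[cluster_key].isin([clust])] for clust in adata.obs[cluster_key].cat.categories]
for dat in adatas:
    # Only downsample groups that exceed the target; smaller groups are kept whole
    if dat.n_obs > target_cells:
        sc.pp.subsample(dat, n_obs=target_cells)
adata_downsampled = adatas[0].concatenate(*adatas[1:])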
Hope that helps.
Thank you @LuckyMD, it worked!
I'll reopen this cause I think it's quite relevant still and could be very straightforward to implement with sklearn resample
also, there is an entire package for subsampling strategies which is probably quite relevant: https://github.com/scikit-learn-contrib/imbalanced-learn
line here for reference: https://github.com/theislab/scanpy/blob/48cc7b38f1f31a78902a892041902cc810ddfcd3/scanpy/preprocessing/_simple.py#L857
back here reminding myself that this would be very useful feature to have...
@bio-la also expressed some interest here on MM
@giovp, did you have a particular strategy in mind for resampling?
So assuming that we are only interested in downsampling, I'd say NearMiss and related are straightforward and scalable (just need to compute a kmeans, which is really fast).
Also, the fact that reshuffling is performed is not in the docs and should be documented. @bio-la do you plan to work on this?
> then I'd say NearMiss and related are straightforward and scalable (just need to compute a kmeans which is really fast)
For sampling from datasets, I would want to go with either something extremely straightforward or something that has been shown to work. Maybe we could start by using provided labels to downsample by?
> reshuffling is performed
Reshuffling meaning that the order is changed?
Linking some previous discussion:
- https://github.com/theislab/scanpy/pull/943
- https://github.com/theislab/scanpy/pull/1382
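On the reshuffling question: a minimal sketch of the effect, assuming the subsampling step takes the head of a random permutation of row positions (which is how the line linked above reads; this is an assumption about that commit, not a statement about current scanpy behaviour). The variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_keep = 10, 4

# Taking the head of a random permutation selects rows in shuffled order.
shuffled = rng.permutation(n_obs)[:n_keep]

# Sorting the selected positions keeps the same rows in their original order.
ordered = np.sort(shuffled)

print("shuffled positions:", shuffled)
print("same rows, original order:", ordered)
```

If the original row order matters downstream, sorting the selected positions before indexing is one simple way to preserve it.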
> clust
In scanpy 1.8, this works:
`target_cells = 3000
adatas = [adata_train[adata_train.obs[cluster_key].isin([clust])] for clust in adata_train.obs[cluster_key].cat.categories]
for dat in adatas:
    if dat.n_obs > target_cells:
        sc.pp.subsample(dat, n_obs=target_cells, random_state=0)
adata_train_downsampled1 = adatas[0].concatenate(*adatas[1:])`
This function at least subsamples all classes in an obs column to the same number of cells. It would be straightforward to modify it to do what you probably have in mind.
import numpy as np

def obs_key_wise_subsampling(adata, obs_key, N):
    '''
    Subsample each class to the same number of cells (N). Classes are given by obs_key pointing to a categorical in adata.obs.
    '''
    counts = adata.obs[obs_key].value_counts()
    # subsample indices per group defined by obs_key
    indices = [np.random.choice(adata.obs_names[adata.obs[obs_key]==group], size=N, replace=False) for group in counts.index]
    selection = np.hstack(np.array(indices))
    return adata[selection].copy()
@stefanpeidli's code gives this error:
ValueError: Cannot take a larger sample than population when 'replace=False'
If a group has fewer than the required number of observations, it shouldn't be subsampled.
target_cells = 1000
cluster_key = "cell_type"
grouped = adata.obs.groupby(cluster_key)
downsampled_indices = []
for _, group in grouped:
    if len(group) > target_cells:
        # Group is larger than the target: draw target_cells rows at random
        downsampled_indices.extend(group.sample(target_cells).index)
    else:
        # Group is at or below the target: keep every row
        downsampled_indices.extend(group.index)
adata_downsampled = adata[downsampled_indices]
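A possible variant of the snippet above, in case reproducibility and the original cell order matter. This is only a sketch working on a plain `obs`-like DataFrame: the helper name `capped_obs_names` and its parameters are hypothetical (not part of scanpy or anndata), and the toy DataFrame stands in for `adata.obs`.

```python
import numpy as np
import pandas as pd


def capped_obs_names(obs, key, max_per_group, random_state=0):
    """Return obs names with at most `max_per_group` rows per value of
    obs[key], in the original row order. Hypothetical helper, not a scanpy API."""
    rng = np.random.default_rng(random_state)
    keep = np.zeros(len(obs), dtype=bool)
    # GroupBy.indices maps each group label to the integer positions of its rows.
    for _, positions in obs.groupby(key, observed=True).indices.items():
        if len(positions) > max_per_group:
            # Oversized group: draw a random subset without replacement.
            positions = rng.choice(positions, size=max_per_group, replace=False)
        # Groups at or below the cap are kept whole.
        keep[positions] = True
    # Boolean masking preserves the original order of the rows.
    return obs.index[keep]


# Toy stand-in for adata.obs: 8 cells of type "A", 3 of type "B".
obs = pd.DataFrame(
    {"cell_type": ["A"] * 8 + ["B"] * 3},
    index=[f"cell_{i}" for i in range(11)],
)
names = capped_obs_names(obs, "cell_type", max_per_group=5)
counts = obs.loc[names, "cell_type"].value_counts().to_dict()
print(counts)
```

With a real AnnData object (assumed here to be called `adata`), something like `adata[capped_obs_names(adata.obs, "cell_type", 1000)].copy()` should give the downsampled object.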