muon
muon copied to clipboard
A small bug in `mu.tl.mofa`
It is triggered when running the following:
import jax
import numpy as np
import anndata as ad
import muon as mu
from muon import MuData

# Reproducer for `mu.tl.mofa` with `groups_label` pointing at a categorical
# column of integer labels. In the reported run, training converges and the
# model is saved. The call then raises
# ValueError("Empty data passed with indices specified.").

# A single PRNG key, equal to calling jax.random.PRNGKey(1) at each draw.
# NOTE(review): reusing one key for several draws is discouraged by the JAX
# docs, because the draws are not independent. It is kept so the generated
# data (and the per-group sizes N=59/41) match the reported run.
seed_key = jax.random.PRNGKey(1)

# Latent factors: 100 observations x 3 factors, with columns scaled by 2, 3 and 4.
factor_scales = jax.numpy.array([2, 3, 4])
z = factor_scales * jax.random.normal(key=seed_key, shape=(100, 3))

# Loadings: 3 factors x 200 features. y is the resulting 100 x 200 data matrix.
w = jax.random.normal(key=seed_key, shape=(3, 200))
y = z @ w

# Split the 200 features into two views: the first 150 and the last 50.
n_first_view = 150
a1 = ad.AnnData(np.array(y[:, :n_first_view]))
a2 = ad.AnnData(np.array(y[:, n_first_view:]))

mdata = MuData({"a1": a1, "a2": a2})
mdata.var_names_make_unique()

# Assign each observation to group 0 or 1, then store the labels as a
# categorical column. This integer-categorical dtype triggers the failure.
group_labels = jax.random.choice(
    key=seed_key,
    a=jax.numpy.array([0, 1]),
    shape=(100,),
)
mdata.obs["group"] = group_labels
mdata.obs["group"] = mdata.obs["group"].astype("category")

mu.tl.mofa(mdata, groups_label="group")
#########################################################
### __ __ ____ ______ ###
### | \/ |/ __ \| ____/\ _ ###
### | \ / | | | | |__ / \ _| |_ ###
### | |\/| | | | | __/ /\ \_ _| ###
### | | | | |__| | | / ____ \|_| ###
### |_| |_|\____/|_|/_/ \_\ ###
### ###
#########################################################
Loaded view='a1' group='0' with N=59 samples and D=150 features...
Loaded view='a1' group='1' with N=41 samples and D=150 features...
Loaded view='a2' group='0' with N=59 samples and D=50 features...
Loaded view='a2' group='1' with N=41 samples and D=50 features...
Model options:
- Automatic Relevance Determination prior on the factors: True
- Automatic Relevance Determination prior on the weights: True
- Spike-and-slab prior on the factors: False
- Spike-and-slab prior on the weights: True
Likelihoods:
- View 0 (a1): gaussian
- View 1 (a2): gaussian
######################################
## Training the model with seed 1 ##
######################################
Converged!
#######################
## Training finished ##
#######################
Saving model in /tmp/mofa_20220802-101431.hdf5...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [6], in <cell line: 15>()
13 mdata.obs["group"] = jax.random.choice(key=jax.random.PRNGKey(1), a=jax.numpy.array([0, 1]), shape=(100,))
14 mdata.obs.group = mdata.obs.group.astype("category")
---> 15 mu.tl.mofa(mdata, groups_label="group")
File ~/miniconda3/envs/scvi-env/lib/python3.9/site-packages/muon/_core/tools.py:581, in mofa(data, groups_label, use_raw, use_layer, use_var, use_obs, likelihoods, n_factors, scale_views, scale_groups, center_groups, ard_weights, ard_factors, spikeslab_weights, spikeslab_factors, n_iterations, convergence_mode, gpu_mode, use_float32, smooth_covariate, smooth_warping, smooth_kwargs, save_parameters, save_data, save_metadata, seed, outfile, expectations, save_interrupted, verbose, quiet, copy)
579 else:
580 if groups_label:
--> 581 z = pd.DataFrame(z, index=zs).loc[mdata.obs.index.values].to_numpy()
582 data.obsm["X_mofa"] = z
584 # Weights
File ~/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pandas/core/frame.py:694, in DataFrame.__init__(self, data, index, columns, dtype, copy)
684 mgr = dict_to_mgr(
685 # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no
686 # attribute "name"
(...)
691 typ=manager,
692 )
693 else:
--> 694 mgr = ndarray_to_mgr(
695 data,
696 index,
697 columns,
698 dtype=dtype,
699 copy=copy,
700 typ=manager,
701 )
703 # For data is list-like, or Iterable (will consume into list)
704 elif is_list_like(data):
File ~/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pandas/core/internals/construction.py:351, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
346 # _prep_ndarray ensures that values.ndim == 2 at this point
347 index, columns = _get_axes(
348 values.shape[0], values.shape[1], index=index, columns=columns
349 )
--> 351 _check_values_indices_shape_match(values, index, columns)
353 if typ == "array":
355 if issubclass(values.dtype.type, str):
File ~/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pandas/core/internals/construction.py:418, in _check_values_indices_shape_match(values, index, columns)
414 if values.shape[1] != len(columns) or values.shape[0] != len(index):
415 # Could let this raise in Block constructor, but we get a more
416 # helpful exception message this way.
417 if values.shape[0] == 0:
--> 418 raise ValueError("Empty data passed with indices specified.")
420 passed = values.shape
421 implied = (len(index), len(columns))
ValueError: Empty data passed with indices specified.
However, if the column is converted to strings instead,
# Workaround: store the group labels as strings rather than integer categories.
mdata.obs["group"] = mdata.obs["group"].astype(str)
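the call runs without error.
So the failure depends on the dtype of the `group` column: integer categories fail, while string labels work.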
Thanks @liujilei156231, I can reproduce this, will fix it for the next release.