Generalised EM algorithm: Implement bic() and predict() methods similar to the sklearn framework

Open ndehio opened this issue 4 years ago • 4 comments

Hello, some time ago I raised issue 1018 regarding GMM learning on a Hypersphere manifold. It was closed by @nguigs via PR 1026.

Is there an example available for that new feature that I can use as a basis? I tried the following and still receive the same error as before (AttributeError: 'HypersphereMetric' object has no attribute 'normalization_factor_init'):

import matplotlib.pyplot as plt
import numpy as np
import geomstats.visualization as visualization
import geomstats.backend as gs
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.expectation_maximization import RiemannianEM
from geomstats.learning.kmeans import RiemannianKMeans

sphere = Hypersphere(dim=2)
cluster = sphere.random_von_mises_fisher(kappa=20, n_samples=140)
SO3 = SpecialOrthogonal(3)
rotation1 = SO3.random_uniform()
rotation2 = SO3.random_uniform()
cluster_1 = cluster @ rotation1
cluster_2 = cluster @ rotation2

manifold = Hypersphere(dim=2)
metric = manifold.metric
data = gs.concatenate((cluster_1, cluster_2), axis=0)

kmeans = RiemannianKMeans(metric, 2, tol=1e-3)
kmeans.fit(data)
labels = kmeans.predict(data)
centroids = kmeans.centroids

EM = RiemannianEM(n_gaussians=2, metric=metric)
means, variances, mixture_coefficients = EM.fit(data=data, max_iter=100)
bic = EM.bic(data)
labels = EM.predict(data)

Maybe this is related to my outdated installation? I did

pip3 uninstall geomstats
pip3 install geomstats

which didn't help. Also I tried

git clone https://github.com/geomstats/geomstats.git
pip3 install -r requirements.txt

according to the installation instructions, which, however, does not seem to install geomstats. Do I need to run the setup.py file as well, and if so, how? Thanks, Niels

ndehio avatar Jul 20 '21 09:07 ndehio

Hi Niels, sorry for the inconvenience. We have not released a new version of geomstats on PyPI yet, so reinstalling with pip won't get you the update from #1026. However, cloning the master branch of the geomstats repository will, and no further installation is required except adding geomstats to your path. You can use sys.path.append(dir), where dir is the path to the cloned repository.

Then your code works for me except the last two lines (bic and predict are not implemented).
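
The path setup described above can be sketched as follows. The repository location is a hypothetical placeholder, and the sketch uses sys.path.insert(0, ...) rather than sys.path.append as a small variation: Python searches sys.path in order, so putting the clone first should let it take precedence over a release that is still installed via pip.

```python
import sys
from pathlib import Path

# Hypothetical location of the cloned repository; adjust to your own checkout.
repo_dir = Path.home() / "geomstats"

# sys.path is searched in order, so inserting at position 0 makes the clone
# win over any pip-installed geomstats. sys.path.append(...) would put it last.
if str(repo_dir) not in sys.path:
    sys.path.insert(0, str(repo_dir))

# After this, `import geomstats` should resolve to the clone, provided the
# directory above really contains the repository.
```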

nguigs avatar Jul 20 '21 12:07 nguigs

Yes, that works for me, thanks!

Are there plans to implement bic() and predict() methods similar to the sklearn framework?

ndehio avatar Jul 21 '21 08:07 ndehio

Now there are! Not sure who is available to tackle this soon. Would you be willing to give it a try, or even just send over a code snippet that we can integrate into our EM?

ninamiolane avatar Jul 21 '21 10:07 ninamiolane

Thanks! I won't be available for the next three weeks; afterwards I can have a look at it.

ndehio avatar Jul 21 '21 10:07 ndehio
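
For reference, the two requested methods can be sketched independently of geomstats. This is only a minimal illustration, not the library's API. It assumes the fitted EM can provide a matrix of posterior probabilities (one row per point, one column per component) and the total log-likelihood of the data. Whether RiemannianEM exposes these is not stated in this thread. The function names and the parameter count (one mean and one scalar variance per component, plus mixture weights that sum to one) are assumptions made for the example.

```python
import numpy as np


def predict_from_posteriors(posteriors):
    """Hard labels: index of the most probable component for each point."""
    return np.argmax(posteriors, axis=1)


def count_free_params(n_components, dim):
    """Free parameters under the assumptions stated in the text above.

    Each component has a mean on a dim-dimensional manifold and one scalar
    variance; the mixture weights sum to one, so one weight is determined.
    """
    return n_components * dim + n_components + (n_components - 1)


def bic_from_log_likelihood(log_likelihood, n_samples, n_free_params):
    """BIC = k * ln(n) - 2 * ln(L); lower is better, as in sklearn."""
    return n_free_params * np.log(n_samples) - 2.0 * log_likelihood


# Toy usage with made-up posteriors for three points and two components.
posteriors = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = predict_from_posteriors(posteriors)
n_params = count_free_params(n_components=2, dim=2)
bic = bic_from_log_likelihood(log_likelihood=-10.0, n_samples=100,
                              n_free_params=n_params)
```

Keeping the computation separate from the EM class means only two quantities would need to be exposed by the fitted model, which may make integration easier.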