Generalised EM algorithm: Implement bic() and predict() methods similar to the sklearn framework
Hello, some time ago I raised issue #1018 regarding GMM learning on a hypersphere manifold. It was closed by @nguigs via PR #1026.
Is there an example available for that new feature that I can use as a basis? I tried the following and still receive the same error as before (AttributeError: 'HypersphereMetric' object has no attribute 'normalization_factor_init'):
import matplotlib.pyplot as plt
import numpy as np
import geomstats.visualization as visualization
import geomstats.backend as gs
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.expectation_maximization import RiemannianEM
from geomstats.learning.kmeans import RiemannianKMeans
sphere = Hypersphere(dim=2)
cluster = sphere.random_von_mises_fisher(kappa=20, n_samples=140)
SO3 = SpecialOrthogonal(3)
rotation1 = SO3.random_uniform()
rotation2 = SO3.random_uniform()
cluster_1 = cluster @ rotation1
cluster_2 = cluster @ rotation2
manifold = Hypersphere(dim=2)
metric = manifold.metric
data = gs.concatenate((cluster_1, cluster_2), axis=0)
kmeans = RiemannianKMeans(metric, 2, tol=1e-3)
kmeans.fit(data)
labels = kmeans.predict(data)
centroids = kmeans.centroids
EM = RiemannianEM(n_gaussians=2, metric=metric)
means, variances, mixture_coefficients = EM.fit(data=data, max_iter=100)
bic = EM.bic(data)
labels = EM.predict(data)
Maybe this is related to an outdated installation? I ran
pip3 uninstall geomstats
pip3 install geomstats
which didn't help. I also tried
git clone https://github.com/geomstats/geomstats.git
pip3 install -r requirements.txt
following the installation instructions, but this does not seem to install geomstats itself. Do I need to run setup.py as well, and if so, how?
Thanks
Niels
Hi Niels, sorry for the inconvenience.
We have not released a new version of geomstats on PyPI yet, so reinstalling with pip won't get you the update from #1026.
However, cloning the master branch of the geomstats repository will, and no further installation is required beyond adding geomstats to your path. You can use sys.path.append(dir), where dir is the path to the cloned repository.
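For reference, the path setup could look like the minimal sketch below. The directory `~/geomstats` is a hypothetical location for the clone, not something stated in this thread. The sketch uses `sys.path.insert(0, ...)` rather than `sys.path.append(...)`: if an older pip-installed release is still present, an appended entry is searched after it, so the old release would be imported instead of the clone.

```python
import os
import sys

# Hypothetical location of the cloned repository; adjust to your machine.
repo_dir = os.path.expanduser("~/geomstats")

# Put the clone at the front of the import path so that it takes precedence
# over any release that pip may have installed into site-packages.
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

# From here on, `import geomstats` resolves to the clone, provided that
# repo_dir contains the `geomstats` package directory.
```

This has to run before the first `import geomstats` in the script or notebook.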
Then your code works for me, except for the last two lines (bic and predict are not implemented).
Yes, that works for me, thanks!
Are there plans to implement bic() and predict() methods similar to the sklearn framework?
Now there are! Not sure who is available to tackle this soon. Would you be willing to give it a try, e.g. by sending over a code snippet that we can integrate into our EM?
Thanks! I won't be available for the next three weeks, but I can have a look at it afterwards.
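As a possible starting point for the bic() and predict() methods discussed above, here is a minimal NumPy sketch. It is not geomstats code. It assumes that a matrix `weighted_pdf` is available, with entry `[i, k]` equal to the mixture coefficient of component k times its density at sample i. The function names `predict_labels` and `bic_score` are hypothetical, and so is the free-parameter count: `dim` coordinates for each mean, one variance per component, and `n_components - 1` independent mixture coefficients. The BIC sign convention follows sklearn, where lower is better.

```python
import numpy as np


def predict_labels(weighted_pdf):
    """Hard assignment: index of the component with the largest posterior."""
    # Posteriors are the rows of weighted_pdf divided by their row sums;
    # that normalisation does not change the argmax, so it is skipped.
    return np.argmax(weighted_pdf, axis=1)


def bic_score(weighted_pdf, n_free_params):
    """BIC = k * ln(n) - 2 * ln(L); lower is better."""
    n_samples = weighted_pdf.shape[0]
    # The mixture likelihood of one sample is the sum over components.
    log_likelihood = np.sum(np.log(np.sum(weighted_pdf, axis=1)))
    return n_free_params * np.log(n_samples) - 2.0 * log_likelihood


# Toy values for illustration only: 3 samples, 2 components.
weighted_pdf = np.array([[0.40, 0.10], [0.05, 0.30], [0.25, 0.20]])

# Assumed parameter count for 2 components on a 2-dimensional manifold.
n_components, dim = 2, 2
n_free_params = n_components * (dim + 1) + (n_components - 1)

labels = predict_labels(weighted_pdf)
bic = bic_score(weighted_pdf, n_free_params)
```

In an actual integration, `weighted_pdf` would have to be computed from the fitted means, variances and mixture coefficients using the Riemannian Gaussian density and its normalization factor, which this sketch leaves out.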