scikit-learn-intelex
Numerical precision issues in K-means
The inertia computed in daal can differ from the true inertia (the one computed from the labels) with an arbitrarily large relative error. Below is an example of this behavior with Intel's scikit-learn:
import numpy as np
from sklearn.cluster import KMeans
X = np.array([[-1.0001],[-0.9999],[0.9999],[1.0001]], dtype=np.float32)
km = KMeans(n_clusters=2, n_init=1, algorithm='full')
km.fit(X)
print(km.cluster_centers_)
# [[ 1.]
#  [-1.]]
print(km.inertia_)
# 0.0
print(((X - km.cluster_centers_[km.labels_])**2).sum())
# 4.0013276e-08
Here the fitted and the exact inertia differ in every digit!
The issue is not with the centers, which are correct. Computing the distances with ||x||² - 2x.y + ||y||², which can lead to catastrophic cancellation, is fine for finding the clusters but not for computing the inertia.
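The cancellation can be reproduced on a single sample from the example above. This is a minimal NumPy sketch, not daal's code: the helper names `sq_dist_expanded` and `sq_dist_direct` are hypothetical, and both simply evaluate the squared distance between the float32 sample 0.9999 and its center 1.0, once with the expanded formula and once directly, against a float64 reference.

```python
import numpy as np


def sq_dist_expanded(x, c):
    """Squared distance via ||x||^2 - 2 x.c + ||c||^2 (fast, cancellation-prone)."""
    return np.dot(x, x) - np.float32(2.0) * np.dot(x, c) + np.dot(c, c)


def sq_dist_direct(x, c):
    """Squared distance via ||x - c||^2 (the numerically safe form)."""
    d = x - c
    return np.dot(d, d)


x = np.array([0.9999], dtype=np.float32)  # one sample from the example above
c = np.array([1.0], dtype=np.float32)     # its cluster center

# Reference: the same float32 inputs, evaluated in float64.
ref = float(np.sum((x.astype(np.float64) - c.astype(np.float64)) ** 2))

for name, fn in [("expanded", sq_dist_expanded), ("direct", sq_dist_direct)]:
    value = float(fn(x, c))
    print(f"{name:8s} {value:.6e}  relative error {abs(value - ref) / ref:.2e}")
```

In this case every intermediate term of the expanded form is close to 1, so in float32 its result can only be a multiple of about 6e-8, while the true squared distance is about 1e-8; its relative error is therefore of order 1 or more. The direct form subtracts first (exactly, since x and c are close) and only rounds once when squaring.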
In daal, each time a cluster is found to be the closest to a sample, their distance, computed with the above formula, is added to the inertia. The right approach would be not to update the inertia incrementally but to compute it from the labels and centers at the end of the iteration.
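A NumPy sketch of the difference, using the data and centers from the example above. `lloyd_assign_incremental` is a hypothetical stand-in that mimics the behavior described here (summing the expanded-form distances while assigning labels); it is not daal's actual implementation. `inertia_from_labels`, also hypothetical, recomputes the inertia once from the final labels and centers with the safe formula, as proposed.

```python
import numpy as np


def lloyd_assign_incremental(X, centers):
    """Assign each sample to its nearest center and accumulate inertia on the fly
    from the expanded-form distances ||x||^2 - 2 x.c + ||c||^2."""
    x_sq = np.einsum("ij,ij->i", X, X)[:, None]
    c_sq = np.einsum("ij,ij->i", centers, centers)[None, :]
    d2 = x_sq - np.float32(2.0) * (X @ centers.T) + c_sq
    labels = d2.argmin(axis=1)
    inertia = d2[np.arange(X.shape[0]), labels].sum()
    return labels, float(inertia)


def inertia_from_labels(X, centers, labels):
    """Recompute inertia once, at the end, as sum of ||x - c_label||^2."""
    diff = X - centers[labels]
    return float(np.einsum("ij,ij->", diff, diff))


X = np.array([[-1.0001], [-0.9999], [0.9999], [1.0001]], dtype=np.float32)
centers = np.array([[1.0], [-1.0]], dtype=np.float32)

labels, inertia_incremental = lloyd_assign_incremental(X, centers)
inertia_final = inertia_from_labels(X, centers, labels)
print(labels, inertia_incremental, inertia_final)
```

The expanded form is still a reasonable choice for the assignment step (the argmin is unaffected here because the two candidate distances differ by about 4); only the reported inertia needs the extra pass.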
I think this is the reason for the failing sklearn test: with several inits, the same clustering (with permuted labels) can be found with different computed inertias. Can you confirm that the failure only occurs with the lloyd algorithm (elkan uses the safe distance formula)?
@fschlimb @oleksandr-pavlyk @ogrisel
Here is what I'm getting:
In [1]: import numpy as np
...: from sklearn.cluster import KMeans
...:
...: X = np.array([[-1.0001],[-0.9999],[0.9999],[1.0001]], dtype=np.float32)
...: km = KMeans(n_clusters=2, n_init=1, algorithm='full')
...: km.fit(X)
Out[1]:
KMeans_daal4py(algorithm='full', copy_x=True, init='k-means++', max_iter=300,
        n_clusters=2, n_init=1, n_jobs=None, precompute_distances='auto',
        random_state=None, tol=0.0001, verbose=0)

In [2]: km.cluster_centers_
Out[2]:
array([[ 1.],
       [-1.]], dtype=float32)
In [3]: km.inertia_
Out[3]: 8.002655e-08
In [4]: ((X - km.cluster_centers_[km.labels_])**2).sum()
Out[4]: 4.0013276e-08
Could you describe your setup, please? The outputs of conda list scikit-learn and conda list daal would be useful.
Here is the output of conda list for scikit-learn and daal:
Name            Version              Build           Channel
scikit-learn    0.19.1               np114py36_35    intel
daal            2019.0               intel_117       intel
pydaal          2019.0.0.20180713    np114py36_0     intel
After upgrading, conda list gives
Name            Version              Build           Channel
scikit-learn    0.20.0               py36_17         intel
daal            2019.1               intel_144       intel
daal4py         0.2019.1.1           py36_0          intel
pydaal          2019.0.1.20181005    py36_1          intel
And then I get your result.
I think I confirmed your intuition about the cause of the failure of test_k_means_fit_predict for float32 with algo='full'. Because of numerical issues when computing distances in float32, the computed centers end up slightly displaced relative to each other from run to run, even with the same initial cluster positions, in a multi-threaded environment. As a result, labels get assigned differently.
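A toy illustration of why threading can change the result: float32 addition is not associative, so a reduction split into per-thread partial sums can round differently from a sequential one. The helpers `sequential_sum` and `chunked_sum` are hypothetical and only model this effect; they make no claim about how DAAL actually partitions work between threads.

```python
import numpy as np


def sequential_sum(values):
    """Left-to-right float32 accumulation, as a single thread would do it."""
    total = np.float32(0.0)
    for v in values:
        total = np.float32(total + v)  # round to float32 after every addition
    return total


def chunked_sum(values, n_chunks):
    """Per-'thread' partial sums over contiguous chunks, then a final reduction."""
    partials = [sequential_sum(chunk) for chunk in np.array_split(values, n_chunks)]
    return sequential_sum(partials)


# The exact sum is 2.0; near 1e8 the float32 spacing is 8, so each 1.0 can be absorbed.
values = np.array([1e8, 1.0, -1e8, 1.0], dtype=np.float32)
print(sequential_sum(values))  # 1.0 (one "thread")
print(chunked_sum(values, 2))  # 0.0 (two "threads")
```

With realistic data the discrepancies are much smaller, but a last-bit difference in a centroid coordinate is enough to make two runs' centers compare unequal, as in the second failure shown further down.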
The v_measure_score of the labels remains the same, though, and the difference between the inertias stays within the requested tolerance.
I asked the DAAL team to take a look, but I think we should relax the test to check v_measure_score rather than identical labels.
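A sketch of what the relaxed check could look like, assuming scikit-learn's `sklearn.metrics.v_measure_score`: two runs pass if their labelings describe the same partition (V-measure of 1.0, which does not depend on the label values themselves) and their inertias agree within a tolerance. The helper name `assert_equivalent_clustering` and the default `rtol` are illustrative choices, not the actual sklearn test.

```python
import numpy as np
from sklearn.metrics import v_measure_score


def assert_equivalent_clustering(labels_1, labels_2, inertia_1, inertia_2, rtol=1e-6):
    """Pass if two runs found the same partition (up to a relabeling)
    and comparable inertias."""
    score = v_measure_score(labels_1, labels_2)
    assert np.isclose(score, 1.0), f"partitions differ: v_measure_score={score}"
    np.testing.assert_allclose(inertia_1, inertia_2, rtol=rtol)


# Same partition with permuted label values: passes.
assert_equivalent_clustering([0, 0, 1, 1, 2], [2, 2, 0, 0, 1], 1.0, 1.0)
```

Comparing with `np.isclose` rather than `==` leaves room for rounding in the entropy computations behind the score.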
This is an example I ran:
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.utils.testing import assert_array_equal
import daal4py
from sklearn.cluster.k_means_ import _k_init
from sklearn.metrics import v_measure_score

seed = 3
n_clusters = 10
dtype = np.single
tol = 1e-8
max_iter = 300

X = make_blobs(n_samples=1000, n_features=10, centers=n_clusters,
               random_state=seed)[0].astype(dtype, copy=False)
X = np.asarray(X)
n_init = 10


def get_starting_centroids():
    """Compute n_init starting values for given seed"""
    sum_of_squares_X = np.square(X).sum(axis=1)
    random_state = np.random.RandomState(seed)
    cc_list = []
    for k in range(n_init):
        starting_centroids_ = _k_init(X, n_clusters, sum_of_squares_X, random_state)
        cc_list.append(starting_centroids_)
    return cc_list


def run_kmeans(cc_list):
    kmeans_algo = daal4py.kmeans(
        nClusters=n_clusters,
        maxIterations=max_iter,
        assignFlag=True,
        accuracyThreshold=tol,
        fptype='float',
        method='defaultDense')
    best_labels, best_inertia, best_cluster_centers = None, None, None
    best_n_iter = -1
    for cc_i in cc_list:
        starting_centroids_ = cc_i
        res = kmeans_algo.compute(X, starting_centroids_)
        inertia = res.goalFunction[0, 0]
        if best_inertia is None or inertia < best_inertia:
            best_labels = res.assignments.ravel()
            best_cluster_centers = res.centroids
            if n_init > 1:
                best_labels = best_labels.copy()
                best_cluster_centers = best_cluster_centers.copy()
            best_inertia = inertia
            best_n_iter = int(res.nIterations[0, 0])
    return best_cluster_centers, best_labels, best_inertia, best_n_iter


starting_centroids_list = get_starting_centroids()
cc_1, l_1, bi_1, n_iter_1 = run_kmeans(starting_centroids_list)
it = 0
while True:
    cc_2, l_2, bi_2, n_iter_2 = run_kmeans(starting_centroids_list)
    if not np.all(l_1 == l_2):
        print("Failed at iteration {}".format(it))
        print((bi_1, bi_2, bi_1 - bi_2))
        assert n_iter_1 < max_iter
        assert n_iter_2 < max_iter
        print((v_measure_score(l_1, l_2) - 1, v_measure_score(l_2, l_1) - 1))
        assert_array_equal(np.take(cc_1, l_1, axis=0), np.take(cc_2, l_2, axis=0))
        assert_array_equal(l_1, l_2)
    it += 1
Running it, I sometimes get differences only in the labels:
(def) [08:55:13 vmlin test_tmp]$ python daal4py_fit_predict.py
Failed at iteration 3469
(9769.283, 9769.283, 0.0)
(0.0, 0.0)
Traceback (most recent call last):
  File "daal4py_fit_predict.py", line 77, in <module>
    assert_array_equal(l_1, l_2)
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 865, in assert_array_equal
    verbose=verbose, header='Arrays are not equal')
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 789, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Arrays are not equal

(mismatch 90.0%)
 x: array([4, 1, 7, 2, 7, 9, 5, 9, 4, 4, 9, 5, 0, 5, 6, 6, 7, 8, 1, 8, 7, 5,
          8, 2, 0, 8, 9, 4, 3, 1, 9, 6, 4, 6, 4, 3, 9, 7, 3, 7, 9, 6, 2, 1,
          2, 4, 1, 1, 2, 4, 8, 7, 7, 7, 5, 8, 8, 3, 6, 7, 4, 2, 6, 4, 9, 9,...
 y: array([5, 7, 8, 1, 8, 6, 3, 6, 5, 5, 6, 3, 0, 3, 2, 2, 8, 9, 7, 9, 8, 3,
          9, 1, 0, 9, 6, 5, 4, 7, 6, 2, 5, 2, 5, 4, 6, 8, 4, 8, 6, 2, 1, 7,
          1, 5, 7, 7, 1, 5, 9, 8, 8, 8, 3, 9, 9, 4, 2, 8, 5, 1, 2, 5, 6, 6,...
and sometimes the cluster centers associated with each point are slightly off:
(def) [08:57:30 vmlin test_tmp]$ python daal4py_fit_predict.py
Failed at iteration 474
(9769.284, 9769.283, 0.0009765625)
(0.0, 0.0)
Traceback (most recent call last):
  File "daal4py_fit_predict.py", line 76, in <module>
    assert_array_equal( np.take(cc_1, l_1, axis=0), np.take(cc_2, l_2, axis=0) )
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 865, in assert_array_equal
    verbose=verbose, header='Arrays are not equal')
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 789, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Arrays are not equal

(mismatch 78.0%)
 x: array([-7.339982, 0.931489, -6.179488, ..., 1.052105, -4.918344,
          -1.632304], dtype=float32)
 y: array([-7.339981, 0.931489, -6.179488, ..., 1.052105, -4.918344,
          -1.632304], dtype=float32)
In all cases, v_measure_score between the pair of labelings gives 1.0. (The displayed quantity is v_measure_score(labels_1, labels_2) - 1.0.)