Numerical precision issues in K-means

Open jeremiedbb opened this issue 6 years ago • 4 comments

The inertia computed by daal can differ from the true inertia (the one computed from the labels and centers), with an arbitrarily large relative error. Below is an example of this behavior with the Intel distribution of scikit-learn:

import numpy as np
from sklearn.cluster import KMeans

X = np.array([[-1.0001],[-0.9999],[0.9999],[1.0001]], dtype=np.float32)
km = KMeans(n_clusters=2, n_init=1, algorithm='full')
km.fit(X)

km.cluster_centers_
>>> [[ 1.]
     [-1.]]

km.inertia_
>>> 0.0

((X - km.cluster_centers_[km.labels_])**2).sum()
>>> 4.0013276e-08

Here the fitted and exact inertia differ in every digit!

The issue is not with the centers, which are correct. Computing the distances as ||x||² - 2x·y + ||y||², which can suffer from catastrophic cancellation, is fine for finding the closest cluster but not for computing the inertia.
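As a minimal sketch of the cancellation, the snippet below evaluates both distance formulas in float32 on the data from this report, with the labels and centers written out by hand (an assumption made for illustration; it does not call daal or scikit-learn):

```python
import numpy as np

# Data from the report; centers and labels are hard-coded for illustration.
X = np.array([[-1.0001], [-0.9999], [0.9999], [1.0001]], dtype=np.float32)
centers = np.array([[-1.0], [1.0]], dtype=np.float32)
labels = np.array([0, 0, 1, 1])
C = centers[labels]  # center assigned to each sample

# Expanded form: ||x||^2 - 2 x.y + ||y||^2, evaluated entirely in float32.
# The three terms are each close to 1 or 2, so their rounding errors
# (around 1e-7) swamp the true squared distance (around 1e-8).
expanded = ((X * X).sum(axis=1)
            - np.float32(2) * (X * C).sum(axis=1)
            + (C * C).sum(axis=1))

# Direct form: subtract first, then square. No large terms cancel.
direct = ((X - C) ** 2).sum(axis=1)

print("inertia, expanded formula:", expanded.sum())
print("inertia, direct formula:  ", direct.sum())
```

Both formulas agree on which center is closest, since the gap between the two candidate distances is large, but only the direct form gives a usable value for the sum.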

In daal, each time a cluster is found to be the closest to a sample, their distance, computed with the above formula, is added to the inertia. The right way would be not to update the inertia incrementally, but to compute it from the labels and centers at the end of the iteration.
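A rough sketch of that proposal in NumPy: the expanded formula is kept for the assignment step, and the inertia is recomputed once at the end from the final labels and centers. The function names `assign` and `lloyd` are hypothetical, there is no convergence check, and this is not how daal is implemented.

```python
import numpy as np

def assign(X, centers):
    """Label each sample with its closest center (expanded distance formula)."""
    d2 = ((X * X).sum(axis=1)[:, None]
          - 2 * (X @ centers.T)
          + (centers * centers).sum(axis=1)[None, :])
    return d2.argmin(axis=1)

def lloyd(X, init_centers, n_iter=10):
    """Plain Lloyd iterations; inertia is computed only once, at the end."""
    centers = np.array(init_centers, dtype=X.dtype)
    for _ in range(n_iter):
        labels = assign(X, centers)
        for k in range(len(centers)):
            members = X[labels == k]
            if len(members) > 0:  # leave empty clusters where they are
                centers[k] = members.mean(axis=0)
    labels = assign(X, centers)
    # Safe inertia: subtract first, then square, using the final labels.
    inertia = ((X - centers[labels]) ** 2).sum()
    return centers, labels, inertia

X = np.array([[-1.0001], [-0.9999], [0.9999], [1.0001]], dtype=np.float32)
init = np.array([[-0.5], [0.5]], dtype=np.float32)  # arbitrary starting centers
centers, labels, inertia = lloyd(X, init)
print(centers.ravel(), labels, inertia)
```

The extra pass over the data costs one more distance evaluation per sample, which is small compared to the iterations themselves.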

I think this is the reason for the failing sklearn test. With several inits, the same clustering with permuted labels can be found with a different computed inertia. Can you confirm that the failure only occurs with the lloyd algorithm (elkan uses the safe distance formula)?

@fschlimb @oleksandr-pavlyk @ogrisel

jeremiedbb, Dec 07 '18 18:12

Here is what I'm getting:

In [1]: import numpy as np
   ...: from sklearn.cluster import KMeans
   ...:
   ...: X = np.array([[-1.0001],[-0.9999],[0.9999],[1.0001]], dtype=np.float32)
   ...: km = KMeans(n_clusters=2, n_init=1, algorithm='full')
   ...: km.fit(X)
Out[1]:
KMeans_daal4py(algorithm='full', copy_x=True, init='k-means++', max_iter=300,
        n_clusters=2, n_init=1, n_jobs=None, precompute_distances='auto',
        random_state=None, tol=0.0001, verbose=0)

In [2]:

In [2]: km.cluster_centers_
Out[2]:
array([[ 1.],
       [-1.]], dtype=float32)

In [3]: km.inertia_
Out[3]: 8.002655e-08

In [4]: ((X - km.cluster_centers_[km.labels_])**2).sum()
Out[4]: 4.0013276e-08

Could you detail your setup please? Outputs of conda list scikit-learn and conda list daal would be useful.

oleksandr-pavlyk, Dec 12 '18 14:12

Here is the output of conda list scikit-learn and conda list daal:

Name                    Version               Build           Channel
scikit-learn            0.19.1                np114py36_35    intel
daal                    2019.0                intel_117       intel
pydaal                  2019.0.0.20180713     np114py36_0     intel

After upgrading, conda list gives:

Name                    Version               Build           Channel
scikit-learn            0.20.0                py36_17         intel
daal                    2019.1                intel_144       intel
daal4py                 0.2019.1.1            py36_0          intel
pydaal                  2019.0.1.20181005     py36_1          intel

And then I get your result.

jeremiedbb, Dec 13 '18 10:12

I think I confirmed your intuition about the cause of the test_k_means_fit_predict failure for float32 with algo='full'. Because of numerical issues when computing distances in float32, in a multi-threaded environment the computed centers end up slightly displaced relative to each other from run to run, even with the same initial cluster positions. As a result, labels get assigned differently.
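One plausible mechanism for such run-to-run differences (an assumption, not something confirmed in this thread) is that float32 addition is not associative, so a reduction whose order depends on thread scheduling can round differently. A small illustration with hand-picked values:

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

# Same three numbers, two summation orders.
left = (a + b) + c   # the large terms cancel first, so c survives
right = a + (b + c)  # c is absorbed into b by rounding before the cancellation

print(left, right, left == right)
```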

The v_measure_score of the labels remains the same though, and the difference in inertias remains within the requested tolerance.

I asked the DAAL team to take a look, but I think we should relax the test to check v_measure_score rather than identical labels.
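To illustrate why comparing raw labels is too strict, here is a NumPy-only sketch that checks whether two label arrays describe the same partition up to a renaming of the clusters. It is a stand-in for the v_measure_score check, not a replacement for it, and `canonical_labels` is a hypothetical helper, not a scikit-learn function. The sample labels are the first ten entries of each array in the failure output further down.

```python
import numpy as np

def canonical_labels(labels):
    """Rename clusters by order of first appearance, e.g. [5, 5, 2] -> [0, 0, 1]."""
    labels = np.asarray(labels).ravel()
    _, first_idx, inverse = np.unique(labels, return_index=True,
                                      return_inverse=True)
    order = np.argsort(first_idx)           # unique labels sorted by first use
    rank = np.empty_like(order)
    rank[order] = np.arange(order.size)     # new name for each unique label
    return rank[inverse.ravel()]

l_1 = np.array([4, 1, 7, 2, 7, 9, 5, 9, 4, 4])
l_2 = np.array([5, 7, 8, 1, 8, 6, 3, 6, 5, 5])

print(np.array_equal(l_1, l_2))  # raw comparison
print(np.array_equal(canonical_labels(l_1), canonical_labels(l_2)))  # up to renaming
```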

oleksandr-pavlyk, Jan 09 '19 14:01

This is an example I ran:

import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.utils.testing import assert_array_equal
import daal4py
from sklearn.cluster.k_means_ import _k_init
from sklearn.metrics import v_measure_score

seed = 3
n_clusters = 10
dtype = np.single
tol = 1e-8
max_iter = 300
X = make_blobs(n_samples=1000, n_features=10, centers=n_clusters, 
               random_state=seed)[0].astype(dtype, copy=False)
X = np.asarray(X)
n_init = 10

def get_starting_centroids():
    """Compute n_init starting values for given seed"""
    sum_of_squares_X = np.square(X).sum(axis=1)
    random_state = np.random.RandomState(seed)

    cc_list = []
    for k in range(n_init):
        starting_centroids_ = _k_init(X, n_clusters, sum_of_squares_X, random_state)
        cc_list.append(starting_centroids_)

    return cc_list


def run_kmeans(cc_list):

    kmeans_algo = daal4py.kmeans(
        nClusters = n_clusters,
        maxIterations = max_iter,
        assignFlag = True,
        accuracyThreshold = tol,
        fptype = 'float',
        method = 'defaultDense')

    best_labels, best_inertia, best_cluster_centers = None, None, None
    best_n_iter = -1

    for cc_i in cc_list:
        starting_centroids_ = cc_i

        res = kmeans_algo.compute(X, starting_centroids_)

        inertia = res.goalFunction[0,0]
        if best_inertia is None or inertia < best_inertia:
            best_labels = res.assignments.ravel()
            best_cluster_centers = res.centroids
            if n_init > 1:
                best_labels = best_labels.copy()
                best_cluster_centers = best_cluster_centers.copy()
            best_inertia = inertia
            best_n_iter = int(res.nIterations[0,0])

    return best_cluster_centers, best_labels, best_inertia, best_n_iter


starting_centroids_list = get_starting_centroids()
cc_1, l_1, bi_1, n_iter_1 = run_kmeans(starting_centroids_list)

it = 0
while True:
    cc_2, l_2, bi_2, n_iter_2 = run_kmeans(starting_centroids_list)

    if not np.all(l_1 == l_2):
        print("Failed at iteration {}".format(it))
        print((bi_1, bi_2, bi_1 - bi_2))
        assert n_iter_1 < max_iter
        assert n_iter_2 < max_iter
        print( (v_measure_score(l_1, l_2) - 1, v_measure_score(l_2, l_1) - 1) )
        assert_array_equal( np.take(cc_1, l_1, axis=0), np.take(cc_2, l_2, axis=0) )
        assert_array_equal(l_1, l_2)
    it += 1

Running it, I sometimes get only differences in labels:

(def) [08:55:13 vmlin test_tmp]$ python daal4py_fit_predict.py
Failed at iteration 3469
(9769.283, 9769.283, 0.0)
(0.0, 0.0)
Traceback (most recent call last):
  File "daal4py_fit_predict.py", line 77, in <module>
    assert_array_equal(l_1, l_2)
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 865, in assert_array_equal
    verbose=verbose, header='Arrays are not equal')
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 789, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Arrays are not equal

(mismatch 90.0%)
 x: array([4, 1, 7, 2, 7, 9, 5, 9, 4, 4, 9, 5, 0, 5, 6, 6, 7, 8, 1, 8, 7, 5,
       8, 2, 0, 8, 9, 4, 3, 1, 9, 6, 4, 6, 4, 3, 9, 7, 3, 7, 9, 6, 2, 1,
       2, 4, 1, 1, 2, 4, 8, 7, 7, 7, 5, 8, 8, 3, 6, 7, 4, 2, 6, 4, 9, 9,...
 y: array([5, 7, 8, 1, 8, 6, 3, 6, 5, 5, 6, 3, 0, 3, 2, 2, 8, 9, 7, 9, 8, 3,
       9, 1, 0, 9, 6, 5, 4, 7, 6, 2, 5, 2, 5, 4, 6, 8, 4, 8, 6, 2, 1, 7,
       1, 5, 7, 7, 1, 5, 9, 8, 8, 8, 3, 9, 9, 4, 2, 8, 5, 1, 2, 5, 6, 6,...

and sometimes the cluster centers associated with each point are slightly off:

(def) [08:57:30 vmlin test_tmp]$ python daal4py_fit_predict.py
Failed at iteration 474
(9769.284, 9769.283, 0.0009765625)
(0.0, 0.0)
Traceback (most recent call last):
  File "daal4py_fit_predict.py", line 76, in <module>
    assert_array_equal( np.take(cc_1, l_1, axis=0), np.take(cc_2, l_2, axis=0) )
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 865, in assert_array_equal
    verbose=verbose, header='Arrays are not equal')
  File "~/miniconda3_cb3/envs/def/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 789, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Arrays are not equal

(mismatch 78.0%)
 x: array([-7.339982,  0.931489, -6.179488, ...,  1.052105, -4.918344,
       -1.632304], dtype=float32)
 y: array([-7.339981,  0.931489, -6.179488, ...,  1.052105, -4.918344,
       -1.632304], dtype=float32)

In all cases, v_measure_score between the pair of label arrays gives 1.0. (The displayed quantity is v_measure_score(labels_1, labels_2) - 1.0.)

oleksandr-pavlyk, Jan 09 '19 15:01