scikit-learn-extra

Add neighbors algorithm based on NSW graphs

Open LeoSvalov opened this issue 2 years ago • 2 comments

Good afternoon!

I would like to add an algorithm for approximate nearest neighbor search.

The method is based on navigable small world (NSW) graphs, which tend to perform better in high-dimensional data spaces [1] than the existing scikit-learn KDTree and BallTree methods, starting from a data dimension of D > 50.
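To make the idea concrete, here is a minimal, unoptimized sketch of NSW construction and greedy search in Python (NumPy plus `heapq`). `SimpleNSW` and its parameters `n_neighbors` and `ef` are hypothetical names chosen for illustration; they are not the API of the proposed `NSWGraph`. Each point is inserted in turn and linked in both directions to its approximate nearest neighbors among the points already in the graph, and a query walks the graph greedily from a random entry point while keeping a candidate list of size `ef`. The search in [1] can be repeated from several random entry points; this sketch uses a single one for brevity.

```python
import heapq
import numpy as np


class SimpleNSW:
    """Minimal navigable small world graph (illustrative sketch, not the PR's NSWGraph)."""

    def __init__(self, n_neighbors=5, ef=20, seed=0):
        self.n_neighbors = n_neighbors  # edges created per inserted point
        self.ef = ef                    # size of the result list kept during search
        self.rng = np.random.RandomState(seed)
        self.data = None
        self.graph = []                 # graph[i] is the set of neighbor ids of node i

    def _dist(self, q, i):
        return float(np.linalg.norm(self.data[i] - q))

    def _search(self, q, ef, n_nodes):
        """Greedy best-first search over nodes 0..n_nodes-1 from one random entry point."""
        entry = int(self.rng.randint(n_nodes))
        d0 = self._dist(q, entry)
        visited = {entry}
        candidates = [(d0, entry)]  # min-heap: closest unexpanded node first
        results = [(-d0, entry)]    # max-heap (negated distances): best ef nodes so far
        while candidates:
            d, node = heapq.heappop(candidates)
            if d > -results[0][0]:
                break  # closest candidate is farther than the worst result: stop
            for nb in self.graph[node]:
                if nb in visited:
                    continue
                visited.add(nb)
                dn = self._dist(q, nb)
                if len(results) < ef or dn < -results[0][0]:
                    heapq.heappush(candidates, (dn, nb))
                    heapq.heappush(results, (-dn, nb))
                    if len(results) > ef:
                        heapq.heappop(results)  # drop the current worst result
        return sorted((-nd, i) for nd, i in results)  # (distance, index), ascending

    def build(self, X):
        self.data = np.asarray(X, dtype=float)
        self.graph = [set() for _ in range(len(self.data))]
        for i in range(1, len(self.data)):
            # Link point i to its approximate nearest neighbors among points 0..i-1.
            found = self._search(self.data[i], max(self.ef, self.n_neighbors), n_nodes=i)
            for _, j in found[: self.n_neighbors]:
                self.graph[i].add(j)
                self.graph[j].add(i)
        return self

    def query(self, X, k=1):
        dists, inds = [], []
        for q in np.asarray(X, dtype=float):
            found = self._search(q, max(self.ef, k), n_nodes=len(self.data))[:k]
            dists.append([d for d, _ in found])
            inds.append([i for _, i in found])
        return np.array(dists), np.array(inds)
```

Because every inserted point is linked to at least one earlier point, the graph stays connected; with `ef` at least the number of points, the search visits every node and becomes exact, while smaller values of `ef` trade accuracy for speed.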

The API is very similar to the existing alternatives, except that NSWGraph can also be used like a k-nearest neighbors classifier, since it implements the base estimator interface (fit/predict).
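For comparison, this is how the existing exact structures in scikit-learn are queried on the same kind of random data as the usage example. `KDTree` and `BallTree` build the index in their constructor and expose `query(X, k=...)`, which returns `(distances, indices)`; the proposed `NSWGraph` follows a similar build-then-query pattern.

```python
import numpy as np
from sklearn.neighbors import BallTree, KDTree

rng = np.random.RandomState(10)
X = rng.random_sample((50, 128))
X_val = rng.random_sample((5, 128))

# Each tree is built at construction time, then queried for the k nearest neighbors.
kd_dists, kd_inds = KDTree(X).query(X_val, k=3)
ball_dists, ball_inds = BallTree(X).query(X_val, k=3)

# Both exact methods return arrays of shape (n_queries, k), sorted by distance.
print(kd_inds.shape)  # (5, 3)
```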

Possible ways to use the method (both examples share these imports):

```python
from sklearn_extra.neighbors import NSWGraph
from sklearn.datasets import load_iris
import numpy as np
```

  1. As an object to query the k nearest neighbors:

```python
rng = np.random.RandomState(10)
X = rng.random_sample((50, 128))  # 50 points in 128 dimensions
nswgraph = NSWGraph()
nswgraph.build(X)  # build the NSW graph over X
X_val = rng.random_sample((5, 128))
dists, inds = nswgraph.query(X_val, k=3)  # distances and indices of the 3 nearest neighbors
```

  2. As a neighbors estimator that takes the target classes of the data into account:

```python
X, y = load_iris(return_X_y=True)
estimator = NSWGraph()
estimator.fit(X, y)  # standard scikit-learn estimator interface
y_pred = estimator.predict(X)
```
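The fit/predict mode presumably follows the usual k-nearest neighbors classification rule: find the k nearest training points and take a majority vote over their labels. The sketch below shows that rule with scikit-learn's exact `KDTree`; `predict_by_vote` is a hypothetical helper, and an approximate index such as an NSW graph would only change how the neighbor indices are obtained.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KDTree


def predict_by_vote(tree, y_train, X_query, k=5):
    """Hypothetical helper: k-NN classification by majority vote over neighbor labels."""
    _, inds = tree.query(X_query, k=k)
    neighbor_labels = y_train[inds]  # shape (n_queries, k)
    # np.bincount counts each class; argmax breaks ties toward the smallest label.
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])


X, y = load_iris(return_X_y=True)
tree = KDTree(X)
y_pred = predict_by_vote(tree, y, X, k=5)
print("training accuracy:", (y_pred == y).mean())
```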

References

[1] Malkov, Y., Ponomarenko, A., Logvinov, A., & Krylov, V. (2014). Approximate nearest neighbor algorithm based on navigable small world graphs. Information Systems, 45, 61-68.

LeoSvalov, May 24 '22 10:05

Thanks for the contribution! While I am not yet sure whether the maintainers would reach a consensus to accept it into the scikit-learn code base, running some benchmarks would surely help.

If your PR can be shown to be approximately competitive in speed with alternative implementations, that would surely help convince maintainers that it is worth investing their time to review it and to accept the long-term maintenance burden that comes with a new method.

Ideally the benchmarks could be based on this existing infrastructure:

  • http://ann-benchmarks.com/
  • https://github.com/erikbern/ann-benchmarks/

In particular, I would be interested in a comparison with nswlib's implementation and with alternative methods not based on NSW graphs, such as https://github.com/lmcinnes/pynndescent.
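As a starting point for such benchmarks, the two quantities that ann-benchmarks plots against each other, recall and queries per second, can be measured with a small harness like the one below. `recall_at_k` and `queries_per_second` are hypothetical helpers, not part of ann-benchmarks; the ground truth comes from scikit-learn's brute-force `NearestNeighbors`, and `KDTree` only stands in for the method under test so that the snippet runs without the PR.

```python
import time

import numpy as np
from sklearn.neighbors import KDTree, NearestNeighbors


def recall_at_k(approx_inds, exact_inds):
    """Fraction of the true k nearest neighbors that the tested method returned."""
    hits = sum(len(set(a) & set(e)) for a, e in zip(approx_inds, exact_inds))
    return hits / exact_inds.size


def queries_per_second(query_fn, X_query, k):
    """Wall-clock throughput of a query(X, k) callable over the whole query set."""
    start = time.perf_counter()
    query_fn(X_query, k)
    return len(X_query) / (time.perf_counter() - start)


rng = np.random.RandomState(0)
X = rng.random_sample((2000, 128))
X_query = rng.random_sample((100, 128))
k = 10

# Ground truth from exhaustive (brute-force) search.
exact = NearestNeighbors(n_neighbors=k, algorithm="brute").fit(X)
_, exact_inds = exact.kneighbors(X_query)

# Any method returning (dists, inds) from a k-NN query can be plugged in here;
# KDTree is exact, so its recall should be 1.0.
tree = KDTree(X)
_, kd_inds = tree.query(X_query, k=k)
recall = recall_at_k(kd_inds, exact_inds)
qps = queries_per_second(lambda Q, n: tree.query(Q, k=n), X_query, k)
print(f"recall@{k}: {recall:.3f}, queries/s: {qps:.1f}")
```

For an approximate method, repeating the measurement over several search-time settings traces the recall versus queries-per-second curve that ann-benchmarks reports.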

ogrisel, Jun 10 '22 18:06

I just realised that this is not the scikit-learn/scikit-learn repo but the scikit-learn-extra repo: I arrived at this PR from the scikit-learn/scikit-learn#23450 issue on the main scikit-learn issue tracker.

I think it would be great to have an implementation of NSW nearest neighbors in scikit-learn-extra. But before reviewing this PR, I would like to see some performance benchmark results as requested above.

ogrisel, Jun 13 '22 07:06