Add JAX kmeans implementation
Summary
This PR adds an implementation of the k-means algorithm written in JAX.
By replacing the current calls to the sklearn implementation with this new implementation, we can now vmap and jit code which uses kmeans for initialization. We can also remove scikit-learn as a core dependency (it is still used in some demos).
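The sketch below illustrates why a pure-JAX k-means helps here. It is a minimal Lloyd's algorithm (random initialisation, fixed number of iterations) that can be jitted and then vmapped over random restarts. It is not the implementation in this PR. The function name kmeans_lloyd, the restart count and the synthetic data are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("num_clusters", "num_iters"))
def kmeans_lloyd(key, X, num_clusters, num_iters=50):
    """Hypothetical sketch: Lloyd's k-means with a fixed iteration count.

    X has shape (num_points, dim); returns (centers, labels, inertia).
    """
    # Random initialisation: pick num_clusters distinct data points.
    idx = jax.random.choice(key, X.shape[0], (num_clusters,), replace=False)
    centers = X[idx]

    def assign(c):
        # Squared Euclidean distance from every point to every centre: (N, K).
        d2 = jnp.sum((X[:, None, :] - c[None, :, :]) ** 2, axis=-1)
        return jnp.argmin(d2, axis=1), jnp.min(d2, axis=1)

    def update(_, c):
        labels, _ = assign(c)
        one_hot = jax.nn.one_hot(labels, num_clusters, dtype=X.dtype)  # (N, K)
        counts = one_hot.sum(axis=0)  # points per cluster, (K,)
        means = (one_hot.T @ X) / jnp.maximum(counts, 1.0)[:, None]
        # Keep the previous centre for any cluster that ended up empty.
        return jnp.where(counts[:, None] > 0, means, c)

    centers = jax.lax.fori_loop(0, num_iters, update, centers)
    labels, min_d2 = assign(centers)
    return centers, labels, min_d2.sum()


# Synthetic data: three well-separated Gaussian blobs, 50 points each.
true_centers = jnp.array([[-5.0, -5.0], [0.0, 5.0], [5.0, -5.0]])
noise = 0.3 * jax.random.normal(jax.random.PRNGKey(42), (3, 50, 2))
X = (true_centers[:, None, :] + noise).reshape(-1, 2)

# vmap over 32 random restarts, then keep the restart with the lowest inertia.
keys = jax.random.split(jax.random.PRNGKey(0), 32)
centers, labels, inertia = jax.vmap(lambda k: kmeans_lloyd(k, X, num_clusters=3))(keys)
best_centers = centers[jnp.argmin(inertia)]
```

Running a fixed number of iterations with jax.lax.fori_loop keeps every array shape static, which is what allows jit and vmap to trace the function. The num_clusters argument has to be static because it determines the shape of the centres.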
Details
The changes are currently structured into different commits as follows:
- Add utility function for sklearn kmeans
  - This just adds a wrapper around the sklearn kmeans implementation to dynamax.utils.cluster.
- Update SSMs to use kmeans utility function
  - Refactor the calls to sklearn kmeans to use the new utility function.
- Add jax implementation of kmeans
  - Add a new implementation of kmeans to dynamax.utils.cluster which is compatible with JAX transformations.
  - Add some basic tests for this implementation.
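A transformation-compatible implementation can still stop early on convergence if it uses jax.lax.while_loop in place of a Python loop. The sketch below shows one way to do this and is not the code in this PR. The name lloyd_until_converged and the stopping rule (largest squared centre shift at or below a tolerance) are assumptions.

```python
import jax
import jax.numpy as jnp


def lloyd_until_converged(X, init_centers, tol=1e-8, max_iters=100):
    """Hypothetical sketch: Lloyd iterations that stop once the largest squared
    centre shift drops to tol or below, or after max_iters iterations."""
    num_clusters = init_centers.shape[0]  # static, taken from the array shape

    def update(centers):
        d2 = jnp.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        one_hot = jax.nn.one_hot(jnp.argmin(d2, axis=1), num_clusters, dtype=X.dtype)
        counts = one_hot.sum(axis=0)
        means = (one_hot.T @ X) / jnp.maximum(counts, 1.0)[:, None]
        # Empty clusters keep their previous centre.
        return jnp.where(counts[:, None] > 0, means, centers)

    def cond(state):
        i, _, shift = state
        return (i < max_iters) & (shift > tol)

    def body(state):
        i, centers, _ = state
        new_centers = update(centers)
        shift = jnp.max(jnp.sum((new_centers - centers) ** 2, axis=-1))
        return i + 1, new_centers, shift

    # Explicit dtypes so the loop carry keeps identical types across iterations.
    init_state = (jnp.array(0, jnp.int32), init_centers,
                  jnp.array(jnp.inf, init_centers.dtype))
    n_iters, centers, _ = jax.lax.while_loop(cond, body, init_state)
    return centers, n_iters


true_centers = jnp.array([[-5.0, -5.0], [0.0, 5.0], [5.0, -5.0]])
noise = 0.3 * jax.random.normal(jax.random.PRNGKey(1), (3, 50, 2))
X = (true_centers[:, None, :] + noise).reshape(-1, 2)

fit = jax.jit(lloyd_until_converged)
centers, n_iters = fit(X, X[jnp.array([0, 50, 100])])

# A batch of two initialisations; vmap over the second argument only.
inits = jnp.stack([X[jnp.array([0, 50, 100])], X[jnp.array([10, 60, 110])]])
batch_centers, batch_iters = jax.vmap(fit, in_axes=(None, 0))(X, inits)
```

Under vmap, the batched loop keeps running until every batch member has met its stopping condition. Members that have already finished have their state held fixed in the meantime.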
Further Testing
It would be nice to be able to test whether the new implementation does roughly as good a job as the sklearn implementation (it is considerably less complex). From experimenting with it so far, once I added k-means++ initialisation it seemed to work pretty well.
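k-means++ initialisation samples each new centre with probability proportional to its squared distance from the nearest centre already chosen. Below is a minimal JAX sketch of that seeding. The function name kmeans_plus_plus_init is hypothetical, and the sketch does not necessarily match the code in this PR.

```python
import jax
import jax.numpy as jnp


def kmeans_plus_plus_init(key, X, num_clusters):
    """Hypothetical sketch of k-means++ seeding for data X of shape (N, D)."""
    num_points = X.shape[0]
    key, subkey = jax.random.split(key)
    # First centre: a data point chosen uniformly at random.
    first = X[jax.random.randint(subkey, (), 0, num_points)]
    centers = jnp.zeros((num_clusters, X.shape[1]), X.dtype).at[0].set(first)
    # Squared distance from each point to its nearest chosen centre so far.
    min_d2 = jnp.sum((X - first) ** 2, axis=-1)

    def add_center(i, carry):
        key, centers, min_d2 = carry
        key, subkey = jax.random.split(key)
        # Sample the next centre with probability proportional to min_d2;
        # already-chosen points have min_d2 == 0, i.e. zero probability.
        idx = jax.random.choice(subkey, num_points, p=min_d2 / min_d2.sum())
        new_center = X[idx]
        centers = centers.at[i].set(new_center)
        min_d2 = jnp.minimum(min_d2, jnp.sum((X - new_center) ** 2, axis=-1))
        return key, centers, min_d2

    _, centers, _ = jax.lax.fori_loop(1, num_clusters, add_center,
                                      (key, centers, min_d2))
    return centers


true_centers = jnp.array([[-5.0, -5.0], [0.0, 5.0], [5.0, -5.0]])
noise = 0.1 * jax.random.normal(jax.random.PRNGKey(2), (3, 50, 2))
X = (true_centers[:, None, :] + noise).reshape(-1, 2)

# num_clusters fixes the output shape, so it is marked static for jit.
init = jax.jit(kmeans_plus_plus_init, static_argnums=2)
seeds = init(jax.random.PRNGKey(0), X, 3)
```

This sketch does not guard against the degenerate case where every point coincides with a chosen centre (min_d2.sum() == 0).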
I'm currently working on some test code where I patch in the new implementation and check that we can get about the same goodness of fit.
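One possible shape for such a check compares the inertia (sum of squared distances to the nearest centre) of the two solutions. It compares centres only after matching them, because cluster labels are arbitrary up to a permutation. The helpers below are hypothetical, the matching uses scipy.optimize.linear_sum_assignment, and the 5% tolerance is an arbitrary assumption.

```python
import jax.numpy as jnp
import numpy as np
from scipy.optimize import linear_sum_assignment


def inertia(X, centers):
    """Sum of squared distances from each point to its nearest centre."""
    d2 = jnp.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    return jnp.min(d2, axis=1).sum()


def matched_center_distance(centers_a, centers_b):
    """Largest distance between paired centres after optimally matching the
    two sets (cluster labels are only defined up to a permutation)."""
    a, b = np.asarray(centers_a), np.asarray(centers_b)
    cost = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)
    return float(cost[rows, cols].max())


def assert_similar_fit(X, centers_new, centers_ref, rtol=0.05):
    """Hypothetical test helper: the new fit's inertia should be at most
    (1 + rtol) times the reference fit's inertia."""
    assert inertia(X, centers_new) <= (1 + rtol) * inertia(X, centers_ref)


X = jnp.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centers_ref = jnp.array([[0.0, 0.5], [10.0, 10.5]])
centers_new = centers_ref[::-1]  # same solution, clusters in the other order
assert_similar_fit(X, centers_new, centers_ref)
```

In a real test, centers_ref could come from sklearn.cluster.KMeans(n_clusters=k).fit(X).cluster_centers_ and centers_new from the JAX implementation, with both run on the same data.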
Final changes before merging
Before merging I will replace the current calls to kmeans_sklearn with kmeans_jax (and perhaps rename it) and remove the kmeans_sklearn function from dynamax.utils.cluster.
Questions:
- Should the kmeans implementation be marked as reserved for internal use only, i.e. not part of the public API?
  - This way we have a bit more leeway to change the interface at a later date without breaking the API.
  - Is this in fact already (implicitly?) the case with the contents of dynamax.utils?
- Any bright ideas on testing this?
Related issues:
Closes #315.
"Should the kmeans implementation be marked as reserved for internal use only?" No! K-means in jax could be of independent interest, so it is worth making it a first-class public citizen. But see also https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.tools.k_means.k_means.html
Possibly also relevant https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.approx_max_k.html
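The assignment step of k-means is an arg-min over centres. jax.lax.approx_min_k with k=1 could therefore stand in for jnp.argmin when the number of centres is large. The sketch below works under that assumption, and its function names are hypothetical. The approximation is aimed at TPU; on other backends the op may simply compute the exact result.

```python
import jax
import jax.numpy as jnp


def squared_distances(X, centers):
    # (N, K) matrix of squared Euclidean distances.
    return jnp.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=-1)


def assign_exact(X, centers):
    return jnp.argmin(squared_distances(X, centers), axis=1)


def assign_approx(X, centers, recall_target=0.95):
    # approx_min_k returns (values, indices) of the k smallest entries along
    # the last axis; with k=1 the index is an (approximate) arg-min.
    _, idx = jax.lax.approx_min_k(
        squared_distances(X, centers), k=1, recall_target=recall_target
    )
    return idx[:, 0]


X = jnp.array([[0.0, 0.0], [0.2, 0.1], [9.8, 10.0], [10.1, 9.9], [-10.0, 5.0]])
centers = jnp.array([[10.0, 10.0], [0.0, 0.0], [-10.0, 5.0]])
labels_exact = assign_exact(X, centers)
labels_approx = jax.jit(assign_approx)(X, centers)
```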