dynamax

Add JAX kmeans implementation

Open · gileshd opened this issue 1 year ago · 1 comment

Summary

This PR adds an implementation of the k-means algorithm written in JAX.

By replacing the current calls to the sklearn implementation with this new implementation, we can now vmap and jit code that uses k-means for initialization. This also lets us remove scikit-learn as a core dependency (it is still used in some demos).
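To show why a JAX version makes vmap and jit possible, here is a minimal sketch of Lloyd's k-means. It is not the code in this PR. The name kmeans_lloyd, its signature, the initialisation from random data points and the fixed iteration count (jax.lax.fori_loop instead of a convergence check) are all assumptions. Fixing the number of iterations and keeping k static keeps every array shape fixed, and jit and vmap rely on that.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _sq_dists(X, centroids):
    # Pairwise squared Euclidean distances, shape (n_points, n_clusters).
    return jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)


@partial(jax.jit, static_argnames=("k", "num_iters"))
def kmeans_lloyd(key, X, k, num_iters=50):
    """Hypothetical Lloyd's k-means sketch; returns (centroids, labels)."""
    n = X.shape[0]
    # Initialise with k distinct data points chosen uniformly at random.
    idx = jax.random.choice(key, n, shape=(k,), replace=False)
    init = X[idx]

    def step(_, centroids):
        # Assignment step: nearest centroid for every point.
        labels = jnp.argmin(_sq_dists(X, centroids), axis=1)
        one_hot = jax.nn.one_hot(labels, k)   # (n, k)
        counts = one_hot.sum(axis=0)          # (k,)
        sums = one_hot.T @ X                  # (k, d)
        # Update step: cluster means; keep the old centroid if a cluster is empty.
        new = sums / jnp.maximum(counts, 1.0)[:, None]
        return jnp.where(counts[:, None] > 0, new, centroids)

    centroids = jax.lax.fori_loop(0, num_iters, step, init)
    labels = jnp.argmin(_sq_dists(X, centroids), axis=1)
    return centroids, labels


# Two well-separated blobs; run 8 random restarts in parallel with vmap.
X = jnp.concatenate([
    0.1 * jax.random.normal(jax.random.PRNGKey(1), (50, 2)),
    0.1 * jax.random.normal(jax.random.PRNGKey(2), (50, 2)) + 5.0,
])
keys = jax.random.split(jax.random.PRNGKey(0), 8)
centroids, labels = jax.vmap(lambda key: kmeans_lloyd(key, X, 2))(keys)
```

The restarts are batched by wrapping the call in a lambda so that k stays a static Python int. Batching restarts this way is not possible through a call into sklearn.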

Details

The changes are currently structured into different commits as follows:

  1. Add utility function for sklearn kmeans
    • This just adds a wrapper around the sklearn k-means implementation to dynamax.utils.cluster.
  2. Update SSMs to use kmeans utility function
    • Refactor the calls to sklearn kmeans to use the new utility function.
  3. Add jax implementation of kmeans
    • Add a new implementation of kmeans to dynamax.utils.cluster which is compatible with JAX transformations.
    • Add some basic tests for this implementation.

Further Testing

It would be nice to be able to test whether the new implementation does roughly as good a job as the sklearn implementation (it is considerably less complex). From playing about with it so far, it seemed to work pretty well once I added k-means++ initialisation.
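k-means++ seeding picks the first centre uniformly at random. Each later centre is drawn with probability proportional to its squared distance from the nearest centre already chosen. Below is a hedged JAX sketch of that rule. kmeans_plusplus_init is a hypothetical name, and the PR's version may differ. The sketch does not guard against fully degenerate data, where every point coincides and all sampling weights are zero.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("k",))
def kmeans_plusplus_init(key, X, k):
    """Hypothetical k-means++ seeding sketch: returns k initial centroids."""
    n, d = X.shape
    key, first_key = jax.random.split(key)
    # First centre: a data point chosen uniformly at random.
    first = X[jax.random.randint(first_key, (), 0, n)]
    centroids = jnp.zeros((k, d), X.dtype).at[0].set(first)
    # Squared distance from each point to its nearest chosen centre so far.
    min_d2 = jnp.sum((X - first) ** 2, axis=1)

    def body(i, carry):
        centroids, min_d2, key = carry
        key, sub = jax.random.split(key)
        # Sample the next centre with probability proportional to D(x)^2.
        idx = jax.random.choice(sub, n, p=min_d2 / jnp.sum(min_d2))
        new = X[idx]
        centroids = centroids.at[i].set(new)
        min_d2 = jnp.minimum(min_d2, jnp.sum((X - new) ** 2, axis=1))
        return centroids, min_d2, key

    centroids, _, _ = jax.lax.fori_loop(1, k, body, (centroids, min_d2, key))
    return centroids


# Three well-separated blobs along the x-axis.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
means = jnp.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]])
X = jnp.concatenate(
    [m + 0.1 * jax.random.normal(kk, (50, 2)) for m, kk in zip(means, (k1, k2, k3))]
)
init = kmeans_plusplus_init(k4, X, 3)
```

The resulting centres can be passed as the initial centroids of a Lloyd's loop in place of uniform random picks.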

I'm currently working on some test code where I patch in the new implementation and check that we can get about the same goodness of fit.
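One simple goodness-of-fit measure for such a comparison is the inertia, which is the sum of squared distances from each point to its nearest centroid. The sketch below assumes the reference centroids would come from sklearn's KMeans(...).fit(X).cluster_centers_. The helper names inertia and similar_fit and the 5% tolerance are hypothetical choices. For the SSM initialisation use case, comparing the final fitted models' log-likelihoods may be the more relevant check.

```python
import jax.numpy as jnp


def inertia(X, centroids):
    """Sum of squared distances from each point to its nearest centroid."""
    d2 = jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.min(d2, axis=1))


def similar_fit(X, centroids_new, centroids_ref, rtol=0.05):
    """Hypothetical check: new inertia is at most (1 + rtol) times the reference."""
    return bool(inertia(X, centroids_new) <= (1.0 + rtol) * inertia(X, centroids_ref))


# Toy data with two clusters; `ref` stands in for sklearn's cluster_centers_.
X = jnp.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
ref = jnp.array([[0.0, 0.5], [10.0, 0.5]])
bad = jnp.array([[0.0, 0.0], [10.0, 0.0]])

ok = similar_fit(X, ref, ref)       # same centroids, so this should pass
not_ok = similar_fit(X, bad, ref)   # worse centroids, so this should fail
```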

Final changes before merging

Before merging I will replace the current calls to kmeans_sklearn with kmeans_jax (and perhaps rename it) and remove the kmeans_sklearn function from dynamax.utils.cluster.

Questions:

  1. Should the k-means implementation be marked as reserved for internal use only, i.e. not part of the public API?
    • This way we have a bit more leeway to change the interface later without breaking the API.
    • Is this in fact already (implicitly) the case for the contents of dynamax.utils?
  2. Any bright ideas on testing this?

Related issues:

Closes #315.

gileshd, Jul 21 '24 16:07

"Should the kmeans implementation be marked as reserved for internal use only?" No! K-means in jax could be of independent interest, so it is worth making it a first class public citizen. But see also https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.tools.k_means.k_means.html

Possibly also relevant: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.approx_max_k.html

murphyk, Jul 22 '24 15:07
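On the approx_max_k pointer: in the assignment step of k-means, each point picks its nearest centroid. That is a top-1 search over distances, so jax.lax.approx_min_k (the minimum counterpart of approx_max_k) could stand in for argmin. A sketch follows, using a hypothetical helper assign_approx. As we understand the JAX docs, the approximate algorithm is TPU-specific, and other backends compute an exact result. With the small numbers of clusters typical of SSM initialisation, a plain argmin is likely just as cheap.

```python
import jax
import jax.numpy as jnp


@jax.jit
def assign_approx(X, centroids):
    """Hypothetical helper: index of the (approximately) nearest centroid per point."""
    d2 = jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)  # (n, k)
    # approx_min_k reduces over the last axis by default; k=1 keeps the best index.
    _, idx = jax.lax.approx_min_k(d2, k=1)
    return idx[:, 0]


X = jnp.array([[0.0, 0.0], [0.2, 0.1], [9.8, 10.0], [10.0, 10.1]])
C = jnp.array([[10.0, 10.0], [0.0, 0.0]])
labels = assign_approx(X, C)
```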