[WIP] Non Negative Matrix Factorization (NMF) and NNLS
What's new?
Two new classes:

NMF, for Non-Negative Matrix Factorization. This non-convex problem is NP-hard [3], but it is bi-convex. Solves::

    min_{H1, H2} 0.5 * ||Y - H1 @ H2.T||_F^2
    s.t. H1 >= 0, H2 >= 0

Hence it is solved by alternating minimization over two convex sub-problems, which are Non-Negative Least Squares (NNLS) problems.

NNLS, for Non-Negative Least Squares. Solves::

    min_H 0.5 * ||Y - W @ H.T||_F^2
    s.t. H >= 0

The implementation follows the pseudo-code in [1], which is based on ADMM [2].
[1] Huang, K., Sidiropoulos, N.D. and Liavas, A.P., 2016. A flexible and efficient algorithmic framework for constrained matrix and tensor factorization. IEEE Transactions on Signal Processing, 64(19), pp.5052-5065.
[2] Boyd, S., Parikh, N., Chu, E., Peleato, B. and Eckstein, J., 2010. Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers. Foundations and Trends in Machine Learning, 3(1), pp.1-122.
[3] Vavasis, S.A., 2010. On the complexity of nonnegative matrix factorization. SIAM Journal on Optimization, 20(3), pp.1364-1377.
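For intuition, here is a rough sketch of the alternating-minimization structure, written with jaxopt's existing ProjectedGradient and a non-negativity projection. It is only an illustration of the bi-convex structure under those assumptions, not the ADMM-based solver proposed in this PR:

```python
import jax.numpy as jnp
from jax import random
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative

def nnls_objective(H, W, Y):
    # Objective of one NNLS sub-problem: min_H 0.5 * ||Y - W @ H.T||_F^2 s.t. H >= 0.
    return 0.5 * jnp.sum((Y - W @ H.T) ** 2)

def nmf_alternating(Y, rank, n_outer=50, seed=0):
    # Alternate between the two convex NNLS sub-problems, keeping one factor fixed.
    k1, k2 = random.split(random.PRNGKey(seed))
    H1 = jnp.abs(random.normal(k1, (Y.shape[0], rank)))
    H2 = jnp.abs(random.normal(k2, (Y.shape[1], rank)))
    pg = ProjectedGradient(fun=nnls_objective, projection=projection_non_negative)
    for _ in range(n_outer):
        H2 = pg.run(H2, W=H1, Y=Y).params    # fix H1, solve for H2
        H1 = pg.run(H1, W=H2, Y=Y.T).params  # fix H2, solve for H1 (on Y.T)
    return H1, H2
```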
Difference with Sklearn
Like Sklearn, this implementation is based on alternating minimization of the two sub-problems. However, Sklearn relies on Block Coordinate Descent, whereas this implementation is based on ADMM and provides dual variables.
At first, I thought those dual variables were needed for implicit differentiation (via the KKT conditions). I have just realized there is another approach: the fixed point of a proximal operator! I don't know which one is faster or more stable.
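To make that alternative concrete, here is a minimal, hypothetical sketch of the proximal-gradient fixed-point approach for a single NNLS sub-problem, built on the library's implicit_diff.custom_fixed_point decorator. The plain fixed-point inner loop below is only for illustration; it is not the ADMM solver of this PR:

```python
import jax.numpy as jnp
from jaxopt import implicit_diff

def prox_grad_step(H, W, Y):
    # One proximal-gradient step for  min_H 0.5 * ||Y - W @ H.T||_F^2  s.t. H >= 0.
    # A solution H* is a fixed point of this map, which gives an optimality
    # condition that does not involve dual variables (unlike the KKT system).
    eta = 1.0 / jnp.linalg.norm(W.T @ W, ord=2)  # step size 1/L with L = ||W^T W||_2
    grad = (W @ H.T - Y).T @ W                   # gradient of the objective w.r.t. H
    return jnp.maximum(H - eta * grad, 0.0)      # projection onto the non-negative orthant

@implicit_diff.custom_fixed_point(prox_grad_step)
def nnls_solve(H_init, W, Y):
    # Hypothetical inner solver: plain fixed-point iterations of the step above.
    # Gradients w.r.t. (W, Y) then come from implicit differentiation of the fixed point.
    H = H_init
    for _ in range(500):
        H = prox_grad_step(H, W, Y)
    return H
```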
Implementation choices
Since the problem is non-convex, the starting point is very important. It is a bit tricky to find a good initialization, so the implementation currently defaults to the ones used by Sklearn.
The nnls_solver parameter of the NMF class allows switching between different solvers for the NNLS sub-problems: NMF is more of a "meta" algorithm for matrix factorization. Note that the pseudo-code of [1] supports arbitrary products of tensors, not only the case Y - U @ W.T.
I noticed that the heuristics provided by [1] for step-size tuning differ from those of OSQP: when in doubt, I implemented both.
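For illustration, usage of the proposed class with the nnls_solver option could look roughly like the snippet below; apart from the class name and nnls_solver, every argument and the run signature here are guesses, not the final API of this PR:

```python
import jax.numpy as jnp
from jax import random

from jaxopt import NMF  # class proposed in this PR, not yet in the library

key = random.PRNGKey(0)
Y = jnp.abs(random.normal(key, (100, 50)))   # non-negative data matrix

# Hypothetical constructor arguments (names are guesses for illustration).
nmf = NMF(n_components=5, nnls_solver="admm", maxiter=200)
(H1, H2), state = nmf.run(Y)                 # Y ~ H1 @ H2.T with H1, H2 >= 0
```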
Why a separate class for NNLS?
NNLS is a special case of a quadratic program.
- It could have been handled by OSQP. But OSQP is very general (can handle arbitrary linear constraints). It may have been overkill to use it.
- Currently it is a simple quadratic program, but NMF usually supports additional constraints such as l2 regularization, l1 regularization, Huber fitting, masking, etc.; see the examples given on page 9 of [1]. With orthogonality constraints on U, NMF becomes equivalent to K-means [4], which allows defining a differentiable K-means layer (in the spirit of [5]): we could outperform [6], for example. A separate class for NNLS makes it easier to add these options.
- BoxCDQP could have done the trick (see the sketch below); unfortunately I implemented NNLS before the release of BoxCDQP. It might be worth switching to it though.
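To make the connection to BoxCDQP explicit, here is a small sketch (with made-up data) of how a single vector NNLS problem maps onto the box-constrained QP that BoxCDQP solves; this is only an illustration, not code from this PR:

```python
import jax.numpy as jnp
from jaxopt import BoxCDQP

# NNLS as a box-constrained QP:
#   min_x 0.5 * ||A @ x - b||^2  s.t.  x >= 0
#   <=>   min_x 0.5 * x^T Q x + c^T x  with  Q = A^T A, c = -A^T b, 0 <= x <= +inf.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # made-up data
b = jnp.array([7.0, 8.0, 9.0])

Q = A.T @ A
c = -A.T @ b
lower = jnp.zeros(A.shape[1])           # non-negativity constraint
upper = jnp.full(A.shape[1], jnp.inf)   # no upper bound

qp = BoxCDQP()
x = qp.run(jnp.zeros(A.shape[1]),
           params_obj=(Q, c),
           params_ineq=(lower, upper)).params
```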
[4] Ding, C., He, X. and Simon, H.D., 2005, April. On the equivalence of nonnegative matrix factorization and spectral clustering. In Proceedings of the 2005 SIAM international conference on data mining (pp. 606-610). Society for Industrial and Applied Mathematics.
[5] Genevay, A., Dulac-Arnold, G. and Vert, J.P., 2019. Differentiable deep clustering with cluster size constraints. arXiv preprint arXiv:1910.09036.
[6] Cho, M., Vahid, K.A., Adya, S. and Rastegari, M., 2021. DKM: Differentiable K-Means Clustering Layer for Neural Network Compression. arXiv preprint arXiv:2108.12659.
TODO
- Fix the computation of implicit diff: some tests are failing
- Compare ADMM-NNLS with BoxCDQP for speed, precision, and implicit differentiation
- In another PR: add support for regularizations (l1, l2, etc.)
Discussion is welcome, especially on ADMM versus Coordinate Descent.
We need to decide if such code is in the scope of the library (it adds maintenance burden and reaching state-of-the-art performance will be a lot of work). Block CD is the state-of-the-art on CPU, not sure on GPU. BTW, we have a short NMF example using Block CD already: https://github.com/google/jaxopt/blob/main/examples/constrained/nmf.py. Regarding implicit diff, my initial thought would have been to use the proximal gradient fixed point too.
> We need to decide if such code is in the scope of the library (it adds maintenance burden and reaching state-of-the-art performance will be a lot of work)
How do you see this library? A general library for constrained/unconstrained/stochastic convex optimization? Like a mix of scipy.optimize, cvxpy and cvxopt? Or do you plan to support solvers for specific problems (e.g. those of sklearn.decomposition)?
In that case, NMF is just a special case of an algorithm that can be implemented with general-purpose optimizers, and could be moved to the examples folder.
I had a quick exchange with @froystig and @fabianp. We prefer to focus on generic solvers in the library. For instance, an SVM solver is out of scope, but BoxCDQP is fine because it is a generic problem class. That said, we would welcome the contribution as an example. Rather than classes, using functions and the custom_fixed_point decorator would be better, to reduce boilerplate. Lastly, we are in the process of switching to notebooks, so it would be better to create the example directly as a notebook (https://jaxopt.github.io/stable/developer.html). Bonus points if you can merge the existing NMF example into your new notebook :)