Multiple constraints for ProjectedGradient?

Open howarth opened this issue 2 years ago • 3 comments

Is there a way to use multiple constraints with ProjectedGradient? For example, can I use a constraint set that is the intersection of the l2 sphere and the set of non-negative vectors?

howarth, Aug 09 '22 00:08

You need to implement the projection operator for the intersection of the two sets, which can be non-trivial in general. In your case, I haven't checked, but projection_l2_sphere(projection_non_negative(.)) might give the correct result?
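A minimal sketch of the composition suggested above, written in plain `jax.numpy` so it does not depend on the exact jaxopt helper signatures. The name `project_nonneg_sphere` and the unused `hyperparams` argument are hypothetical; the two-argument form is an assumption about the `projection(params, hyperparams_proj)` callable that ProjectedGradient expects, and is worth checking against the jaxopt docs.

```python
import jax.numpy as jnp


def project_nonneg_sphere(x, hyperparams=None):
    """Clip negative entries to zero, then rescale onto the unit l2 sphere.

    Hypothetical helper: `hyperparams` is unused and only mirrors the
    assumed (params, hyperparams_proj) projection signature.
    """
    clipped = jnp.maximum(x, 0.0)  # projection onto the non-negative orthant
    norm = jnp.linalg.norm(clipped)
    # Assumes x has at least one positive entry; otherwise `clipped` is all
    # zeros and the division below is undefined.
    return clipped / norm


x = jnp.array([3.0, -1.0, 4.0])
p = project_nonneg_sphere(x)  # approximately [0.6, 0.0, 0.8]
```

The result is non-negative with unit norm by construction; whether it is always the nearest point of the intersection is exactly what the comment above leaves unchecked.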

mblondel, Aug 11 '22 22:08

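According to the Wikipedia article on POCS (projections onto convex sets), repeating x = projection_l2_sphere(projection_non_negative(x)) will converge to a fixed point that lies in the intersection, at least when the sets are convex (the l2 ball is; the l2 sphere is not).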

Maybe we can add a POCS wrapper to Jax to combine an arbitrarily large number of sets in the intersection?
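A rough sketch of what such a wrapper could look like, assuming each set is given as a plain projection function. The names `pocs`, `project_nonneg` and `project_l2_ball` are hypothetical, not jaxopt API, and a fixed iteration count stands in for a real termination test. The l2 ball is used instead of the sphere so that both sets are convex.

```python
import jax
import jax.numpy as jnp


def project_nonneg(x):
    # Projection onto the non-negative orthant.
    return jnp.maximum(x, 0.0)


def project_l2_ball(x, radius=1.0):
    # Projection onto the l2 ball: only shrink points that lie outside it.
    norm = jnp.maximum(jnp.linalg.norm(x), 1e-12)
    return x * jnp.minimum(1.0, radius / norm)


def pocs(x, projections, num_iters=100):
    """Cyclically apply each projection for a fixed number of sweeps."""

    def sweep(_, v):
        for proj in projections:
            v = proj(v)
        return v

    return jax.lax.fori_loop(0, num_iters, sweep, x)


x0 = jnp.array([3.0, -4.0])
x_feasible = pocs(x0, [project_nonneg, project_l2_ball])
```

Note that POCS only looks for some point in the intersection, not necessarily the point closest to `x0`, so on its own it is not guaranteed to be the projection operator that ProjectedGradient needs.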

Or maybe even a more sophisticated method like Dykstra's algorithm. The code would be quite short and the termination condition easy to check. Implicit differentiation would be almost free: it would be applied to the fixed point of the procedure, which in turn would differentiate through each of the atomic projections associated with each set.
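For comparison, a short sketch of Dykstra's alternating projections under the same assumptions (plain projection functions, fixed iteration count, hypothetical names). Unlike plain POCS, it keeps one correction term per set, which is what lets it converge to the actual projection onto the intersection when the sets are closed and convex. The example sets (non-negative orthant and the hyperplane of vectors summing to one) are chosen only for illustration.

```python
import jax.numpy as jnp


def project_nonneg(x):
    # Projection onto the non-negative orthant.
    return jnp.maximum(x, 0.0)


def project_sum_to_one(x):
    # Projection onto the hyperplane {x : sum(x) = 1}.
    return x - (jnp.sum(x) - 1.0) / x.shape[0]


def dykstra(x, projections, num_iters=100):
    """Dykstra's algorithm with one correction vector per set."""
    corrections = [jnp.zeros_like(x) for _ in projections]
    for _ in range(num_iters):
        for i, proj in enumerate(projections):
            shifted = x + corrections[i]      # add back the previous correction
            x = proj(shifted)                 # project onto the i-th set
            corrections[i] = shifted - x      # store the new correction
    return x


x0 = jnp.array([2.0, 0.0])
x_proj = dykstra(x0, [project_sum_to_one, project_nonneg])
```

Here the intersection is the probability simplex, and the projection of `[2, 0]` onto it is `[1, 0]`, which the iterates approach. A real implementation would stop on a tolerance rather than a fixed count, and the implicit differentiation mentioned above is not sketched here.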

Algue-Rythme, Aug 19 '22 14:08

Something like S3CM might also work? I had a version here (not differentiable).

There's also TOS for the deterministic case. Should have started with this, maybe.
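Assuming TOS here means three-operator splitting (Davis-Yin), a minimal sketch could look like the following. It minimizes a smooth function over the intersection of two sets by using one projection per set, so no projection onto the intersection itself is needed. All names are hypothetical, the step size is assumed to be below 2/L for an L-smooth objective, and the fixed iteration count stands in for a proper stopping rule.

```python
import jax
import jax.numpy as jnp


def project_nonneg(x):
    # Projection onto the non-negative orthant.
    return jnp.maximum(x, 0.0)


def project_sum_to_one(x):
    # Projection onto the hyperplane {x : sum(x) = 1}.
    return x - (jnp.sum(x) - 1.0) / x.shape[0]


def three_operator_splitting(fun, proj_g, proj_h, z0, step=1.0, num_iters=100):
    """Davis-Yin iteration with both non-smooth terms being set indicators."""
    grad = jax.grad(fun)
    z = z0
    for _ in range(num_iters):
        x_g = proj_g(z)                                   # step on the first set
        x_h = proj_h(2.0 * x_g - z - step * grad(x_g))    # reflected gradient step, second set
        z = z + x_h - x_g                                 # update the auxiliary variable
    return proj_g(z)


target = jnp.array([2.0, 0.0])


def objective(x):
    # Smooth objective with 1-Lipschitz gradient, so step=1.0 is below 2/L.
    return 0.5 * jnp.sum((x - target) ** 2)


x_tos = three_operator_splitting(
    objective, project_nonneg, project_sum_to_one, jnp.zeros(2)
)
```

With this objective the problem is just the projection of `target` onto the simplex, so the result should approach `[1, 0]`, the same point as a direct projection would give.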

GeoffNN, Sep 12 '22 18:09