jaxopt
Multiple constraints for ProjectedGradient?
Is there a way to have multiple constraints with ProjectedGradient? For example, can the constraint set be the intersection of the L2 sphere and the set of non-negative vectors?
You need to implement the projection operator onto the intersection of the two sets, which can be non-trivial in general. In your case I haven't checked, but projection_l2_sphere(projection_non_negative(.)) might give the correct result.
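A quick way to try the composition is sketched below. The helpers projection_non_negative and projection_l2_sphere here are hypothetical stand-ins written with jax.numpy, not jaxopt's own functions, so the snippet does not depend on jaxopt's exact signatures. A short argument suggests the composition is exact whenever x has at least one positive entry: for y on the sphere of radius r, ||y - x||² = r² + ||x||² - 2⟨y, x⟩, so the projection maximizes ⟨y, x⟩ over y ≥ 0 with ||y|| = r, and ⟨y, x⟩ ≤ ⟨y, max(x, 0)⟩ ≤ r ||max(x, 0)||, with equality at y = r max(x, 0) / ||max(x, 0)||. If x has no positive entry, the clipped vector is zero and the composition is undefined.

```python
import jax.numpy as jnp


def projection_non_negative(x):
    # Hypothetical stand-in: projection onto the non-negative orthant.
    return jnp.maximum(x, 0.0)


def projection_l2_sphere(x, radius=1.0):
    # Hypothetical stand-in: projection onto {y : ||y||_2 = radius}.
    # Undefined at x = 0, where every point of the sphere is equally close.
    return radius * x / jnp.linalg.norm(x)


def projection_sphere_non_negative(x, radius=1.0):
    # Clip first, then rescale. Assumes x has at least one positive entry;
    # otherwise the clipped vector is zero and the rescaling divides by zero.
    return projection_l2_sphere(projection_non_negative(x), radius)


x = jnp.array([3.0, -1.0, 4.0])
y = projection_sphere_non_negative(x)
print(y)  # ≈ [0.6, 0.0, 0.8]
```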
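According to the Wikipedia article on POCS (projections onto convex sets), repeating x = projection_l2_sphere(projection_non_negative(x)) converges to a fixed point that lies in the intersection. That guarantee assumes both sets are convex (the L2 sphere is not), and the limit is a point of the intersection, not necessarily the projection of the starting point onto it.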
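A minimal sketch of the POCS loop, generalized to any tuple of projection functions and stopping once a full sweep no longer moves the iterate. The names pocs, proj_halfspace, proj_a and proj_b are hypothetical. The usage swaps in two convex half-planes, A = {y : y1 ≤ 0} and B = {y : y1 - y0 ≤ 1}, because their projections have closed forms and POCS theory applies to them. Starting from (0, 2), POCS stops at (0.5, 0), which lies in A ∩ B, whereas the closest point of A ∩ B to (0, 2) is (0, 0).

```python
import jax
import jax.numpy as jnp


def pocs(projections, x0, maxiter=1000, tol=1e-6):
    # Hypothetical helper: alternating projections (POCS) over a tuple of
    # projection functions. Returns the last iterate and the sweep count.
    def sweep(x):
        # One sweep applies every projection in turn.
        for proj in projections:
            x = proj(x)
        return x

    def cond_fun(state):
        x, x_prev, it = state
        # Continue while a sweep still moves x, up to maxiter sweeps.
        return (jnp.linalg.norm(x - x_prev) > tol) & (it < maxiter)

    def body_fun(state):
        x, _, it = state
        return sweep(x), x, it + 1

    init = (sweep(x0), x0, jnp.array(1, dtype=jnp.int32))
    x, _, it = jax.lax.while_loop(cond_fun, body_fun, init)
    return x, it


def proj_halfspace(x, a, b):
    # Projection onto the half-space {y : <a, y> <= b}.
    excess = jnp.maximum(jnp.dot(a, x) - b, 0.0)
    return x - excess / jnp.dot(a, a) * a


def proj_a(x):
    # A = {y : y1 <= 0}
    return proj_halfspace(x, jnp.array([0.0, 1.0]), 0.0)


def proj_b(x):
    # B = {y : y1 - y0 <= 1}
    return proj_halfspace(x, jnp.array([-1.0, 1.0]), 1.0)


x, it = pocs((proj_b, proj_a), jnp.array([0.0, 2.0]))
print(x, it)  # x ≈ [0.5, 0.0] after 2 sweeps
```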
Maybe we could add a POCS wrapper to jaxopt that combines an arbitrary number of sets in the intersection?
Or maybe even a more advanced method such as Dykstra's projection algorithm. The code would be quite short and the termination condition easy to check. Implicit differentiation would come almost for free: we would differentiate the fixed point of the procedure, which in turn would differentiate through each of the atomic projections associated with each set.
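Dykstra's algorithm keeps one correction term per set on top of the POCS sweep, and for closed convex sets with a non-empty intersection its iterates converge to the projection of the starting point onto the intersection. A minimal sketch follows; the names dykstra, proj_halfspace, proj_a and proj_b are hypothetical, the stopping rule (x no longer moves over a full sweep) is a simple heuristic, and implicit differentiation is not covered. On the half-planes A = {y : y1 ≤ 0} and B = {y : y1 - y0 ≤ 1}, starting from (0, 2), it reaches (0, 0), the closest point of A ∩ B, after three sweeps, whereas plain alternating projections with the same order stop at (0.5, 0).

```python
import jax
import jax.numpy as jnp


def dykstra(projections, x0, maxiter=1000, tol=1e-6):
    # Hypothetical helper: Dykstra's alternating projection algorithm.
    # ps[i] is the correction term carried over for set i between sweeps.
    def sweep(x, ps):
        new_ps = []
        for proj, p in zip(projections, ps):
            y = proj(x + p)
            new_ps.append(x + p - y)  # what projection i removed this time
            x = y
        return x, tuple(new_ps)

    def cond_fun(state):
        x, _, x_prev, it = state
        # Heuristic stopping rule: x stopped moving over a full sweep.
        return (jnp.linalg.norm(x - x_prev) > tol) & (it < maxiter)

    def body_fun(state):
        x, ps, _, it = state
        x_new, ps_new = sweep(x, ps)
        return x_new, ps_new, x, it + 1

    ps0 = tuple(jnp.zeros_like(x0) for _ in projections)
    x1, ps1 = sweep(x0, ps0)
    init = (x1, ps1, x0, jnp.array(1, dtype=jnp.int32))
    x, _, _, it = jax.lax.while_loop(cond_fun, body_fun, init)
    return x, it


def proj_halfspace(x, a, b):
    # Projection onto the half-space {y : <a, y> <= b}.
    excess = jnp.maximum(jnp.dot(a, x) - b, 0.0)
    return x - excess / jnp.dot(a, a) * a


def proj_a(x):
    # A = {y : y1 <= 0}
    return proj_halfspace(x, jnp.array([0.0, 1.0]), 0.0)


def proj_b(x):
    # B = {y : y1 - y0 <= 1}
    return proj_halfspace(x, jnp.array([-1.0, 1.0]), 1.0)


x, it = dykstra((proj_b, proj_a), jnp.array([0.0, 2.0]))
print(x, it)  # x ≈ [0.0, 0.0], the projection of (0, 2) onto A ∩ B
```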