
ENH - Automatic support of L2 regularization in Penalties

Open · Badr-MOUFAD opened this issue

Given a penalty $f: \mathbb{R}^n \rightarrow \mathbb{R}$ that is already implemented in the package, it is possible to endow it with L2 regularization to get $\Omega = f + \frac{\mu}{2} \lVert \cdot \rVert^2$.

Indeed, for a step size $\sigma$ and a gradient $\mathrm{grad}$, the proximal operator of $\Omega$ and the distance to its subdifferential can be written in terms of the prox and the subdifferential distance of $f$:

$$ \mathrm{prox}_{\Omega, \sigma}(x) = \mathrm{prox}_{f, \frac{\sigma}{1 + \sigma \mu}}\left(\frac{x}{1 + \sigma \mu}\right) $$
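Assuming $\mathrm{prox}_{\Omega,\sigma}$ denotes the proximal operator of $\sigma\Omega$, the first identity follows by completing the square; the second line drops terms that do not depend on $z$:

$$
\begin{aligned}
\mathrm{prox}_{\Omega,\sigma}(x)
&= \arg\min_z \; \frac{1}{2\sigma}\lVert z - x\rVert^2 + \frac{\mu}{2}\lVert z\rVert^2 + f(z) \\
&= \arg\min_z \; \frac{1+\sigma\mu}{2\sigma}\left\lVert z - \frac{x}{1+\sigma\mu}\right\rVert^2 + f(z) \\
&= \mathrm{prox}_{f,\frac{\sigma}{1+\sigma\mu}}\left(\frac{x}{1+\sigma\mu}\right)
\end{aligned}
$$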

$$ \mathrm{dist}_{\partial \Omega(x)}(-\mathrm{grad}) = \mathrm{dist}_{\partial f(x)}(-\mathrm{grad} - \mu x) $$
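A quick numerical sanity check of both identities, taking $f = \lambda\lvert\cdot\rvert$ in one dimension (an example choice, not from the issue), whose prox is soft-thresholding. The second identity holds because the quadratic term is differentiable, so $\partial\Omega(x) = \partial f(x) + \mu x$, and the distance from $-\mathrm{grad}$ to that shifted set equals the distance from $-\mathrm{grad} - \mu x$ to $\partial f(x)$. All function names here are illustrative, not skglm's.

```python
import numpy as np


def soft_threshold(x, level):
    """Prox of level * |.| evaluated at scalar x."""
    return np.sign(x) * max(abs(x) - level, 0.0)


def prox_l1_l2(x, step, lam, mu):
    """Prox of step * (lam*|.| + mu/2*(.)**2), computed with the prox identity."""
    shrink = 1.0 + step * mu
    return soft_threshold(x / shrink, step * lam / shrink)


def prox_brute_force(x, step, lam, mu):
    """Minimize the prox objective directly over a fine grid (spacing 1e-4)."""
    z = np.linspace(-10.0, 10.0, 200_001)
    objective = (z - x) ** 2 / (2 * step) + lam * np.abs(z) + 0.5 * mu * z**2
    return z[np.argmin(objective)]


def dist_subdiff_l1(w, v, lam):
    """Distance from v to the subdifferential of lam*|.| at w."""
    if w != 0:
        return abs(v - lam * np.sign(w))
    return max(abs(v) - lam, 0.0)  # subdifferential at 0 is [-lam, lam]


def dist_subdiff_l1_l2(w, grad, lam, mu):
    """dist(-grad, subdifferential of Omega at w), via the distance identity."""
    return dist_subdiff_l1(w, -grad - mu * w, lam)


def dist_subdiff_direct(w, grad, lam, mu):
    """Same distance, computed from the subdifferential of Omega itself."""
    if w != 0:
        return abs(-grad - (lam * np.sign(w) + mu * w))  # Omega is differentiable here
    return max(abs(grad) - lam, 0.0)  # subdifferential of Omega at 0 is [-lam, lam]


step, lam, mu = 0.5, 1.0, 2.0
xs = [-3.0, -0.2, 0.0, 0.4, 5.0]
pairs = [(0.0, 0.3), (0.0, -2.5), (1.5, -1.0), (-0.7, 4.0)]
for x in xs:
    print(f"x={x:+.1f}  identity={prox_l1_l2(x, step, lam, mu):+.4f}  "
          f"brute force={prox_brute_force(x, step, lam, mu):+.4f}")
for w, grad in pairs:
    print(f"w={w:+.1f}  identity={dist_subdiff_l1_l2(w, grad, lam, mu):.4f}  "
          f"direct={dist_subdiff_direct(w, grad, lam, mu):.4f}")
```

For this penalty the prox identity reproduces the familiar elastic-net prox $\mathrm{soft}(x, \sigma\lambda) / (1 + \sigma\mu)$.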

Implementation

This can be implemented either through inheritance or with a class decorator. This PR provides a proof of concept (POC) of the second approach: to add support for L2 regularization, one only needs to decorate the penalty with `overload_with_l2`.
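A minimal pure-Python sketch of what such a decorator could look like; it is illustrative, not the PR's code. It assumes the penalty exposes `value(w)`, `prox_1d(value, stepsize, j)` and `subdiff_distance(w, grad, ws)` as skglm penalties do, with `grad` indexed like the working set `ws`. The toy `L1` class is a hypothetical stand-in, and the sketch makes no attempt at Numba compatibility: its constructor forwards `*args`/`**kwargs`, the very pattern behind the problem described under Help needed.

```python
import numpy as np


def overload_with_l2(penalty_cls):
    """Illustrative sketch: subclass `penalty_cls` to add mu/2 * ||w||^2."""

    class PenaltyWithL2(penalty_cls):
        # Forwarding *args/**kwargs lets one decorator fit every penalty,
        # whatever its constructor signature.
        def __init__(self, mu, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.mu = mu

        def value(self, w):
            return super().value(w) + 0.5 * self.mu * np.dot(w, w)

        def prox_1d(self, value, stepsize, j):
            # prox identity: rescale both the point and the step by 1 + step * mu
            shrink = 1.0 + stepsize * self.mu
            return super().prox_1d(value / shrink, stepsize / shrink, j)

        def subdiff_distance(self, w, grad, ws):
            # distance identity: -grad becomes -grad - mu * w, i.e. grad + mu * w
            return super().subdiff_distance(w, grad + self.mu * w[ws], ws)

    PenaltyWithL2.__name__ = penalty_cls.__name__ + "WithL2"
    return PenaltyWithL2


class L1:
    """Toy stand-in for an existing penalty: alpha * ||w||_1."""

    def __init__(self, alpha):
        self.alpha = alpha

    def value(self, w):
        return self.alpha * np.sum(np.abs(w))

    def prox_1d(self, value, stepsize, j):
        return np.sign(value) * max(abs(value) - stepsize * self.alpha, 0.0)

    def subdiff_distance(self, w, grad, ws):
        dist = np.zeros(len(ws))
        for idx, j in enumerate(ws):
            if w[j] == 0:
                dist[idx] = max(abs(grad[idx]) - self.alpha, 0.0)
            else:
                dist[idx] = abs(-grad[idx] - np.sign(w[j]) * self.alpha)
        return dist


L1WithL2 = overload_with_l2(L1)
penalty = L1WithL2(mu=2.0, alpha=1.0)
print(penalty.value(np.array([1.0, -2.0])))
print(penalty.prox_1d(-3.0, 0.5, 0))
print(penalty.subdiff_distance(np.array([0.0, 1.0]), np.array([0.5, -4.0]), np.array([0, 1])))
```

Here `L1WithL2(mu=2.0, alpha=1.0)` passes `alpha` through `**kwargs` to `L1.__init__`, which is why the generic constructor needs the variadic arguments.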

Help needed

I unit-tested the logic and the implementation, and everything works as expected. However, I'm running into problems when JIT-compiling the class, as Numba doesn't support `*args` and `**kwargs`, which are required to overload the penalty's constructor.

Is there any workaround for that?
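One direction that might be worth exploring, offered as an untested idea rather than a known fix: generate the constructor with an explicit signature copied from the wrapped penalty's `__init__`, so no `*args`/`**kwargs` remain. The sketch assumes (hypothetically) that each penalty constructor only stores its arguments as attributes of the same name, and it ignores default values; `make_init_with_mu` is a hypothetical helper, and whether Numba accepts an `exec`-generated `__init__` has not been checked here.

```python
import inspect


def make_init_with_mu(penalty_cls):
    """Hypothetical helper: build `__init__(self, mu, <penalty args>)` explicitly.

    Assumes the penalty's constructor only stores each argument as an
    attribute with the same name; default values are not reproduced.
    """
    params = list(inspect.signature(penalty_cls.__init__).parameters)[1:]  # drop `self`
    signature = ", ".join(["self", "mu"] + params)
    body = [f"    self.{name} = {name}" for name in params] + ["    self.mu = mu"]
    source = f"def __init__({signature}):\n" + "\n".join(body) + "\n"
    namespace = {}
    exec(source, namespace)  # compile a plain function with a fixed argument list
    return namespace["__init__"]


class L1:
    """Toy stand-in for an existing penalty."""

    def __init__(self, alpha):
        self.alpha = alpha


class L1WithL2(L1):
    __init__ = make_init_with_mu(L1)


penalty = L1WithL2(2.0, 1.0)
print(inspect.signature(L1WithL2.__init__))
print(penalty.mu, penalty.alpha)
```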

Badr-MOUFAD, Apr 06 '23 16:04