[BUG] FISTA restarted solver diverges after reaching the solution on a Lasso problem
Reproduce with:

```python
import celer
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelBinarizer
from modopt.opt.algorithms import ForwardBackward
from modopt.opt.proximity import SparseThreshold
from modopt.opt.linear import Identity
from modopt.opt.gradient import GradBasic
X, y = fetch_openml("leukemia", return_X_y=True)
X = X.to_numpy()
y = LabelBinarizer().fit_transform(y)[:, 0].astype(X.dtype)
lmbd = 0.5 * np.max(np.abs(X.T @ y))
# problem is easy, celer fits to machine precision in 1 iter
x_star = celer.Lasso(alpha=lmbd / len(y), tol=1e-14, verbose=1,
                     fit_intercept=False).fit(X, y).coef_
restart_strategy = "adaptive-1"
min_beta = None
s_greedy = None
p_lazy = 1 / 30
q_lazy = 1 / 10
def op(w):
    return X @ w
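
# With these operators, GradBasic corresponds to the data-fit term
# 0.5 * ||y - X @ w||**2 (its gradient is X.T @ (X @ w - y)), and
# SparseThreshold(Identity(), lmbd) to the penalty lmbd * ||w||_1, i.e. the
# same Lasso objective celer solves above (celer's alpha = lmbd / len(y)
# merely rescales that objective by 1 / n_samples).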
fb = ForwardBackward(
    x=np.zeros(X.shape[1]),
    grad=GradBasic(
        input_data=y, op=op,
        trans_op=lambda res: X.T @ res,
        input_data_writeable=True,
    ),
    prox=SparseThreshold(Identity(), lmbd),
    beta_param=1.0,
    min_beta=min_beta,
    metric_call_period=None,
    restart_strategy=restart_strategy,
    xi_restart=0.96,
    s_greedy=s_greedy,
    p_lazy=p_lazy,
    q_lazy=q_lazy,
    auto_iterate=False,
    progress=False,
    cost=None,
)
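
# Step size below: 1 / L, with L the Lipschitz constant of the gradient
# (the largest squared singular value of X). Since beta_param=1.0 was already
# passed at construction, the internal step size fb._beta is overridden
# directly as well.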
L = np.linalg.norm(X, ord=2) ** 2
beta_param = 1 / L
fb.beta_param = beta_param
fb._beta = fb.step_size or beta_param
increment = 10
it = 0
iterations = []
distances = []
support_sizes = []
while it < 1700:
    it += increment
    fb.iterate(max_iter=increment)
    x = fb.x_final
    distances.append(norm(x - x_star))
    support_sizes.append((x != 0).sum())
    iterations.append(it)
plt.close('all')
fig, axarr = plt.subplots(2, 1, sharex=True, constrained_layout=True)
axarr[0].semilogy(iterations, distances)
axarr[0].set_ylabel("distance to solution")
axarr[1].plot(iterations, support_sizes)
axarr[1].set_ylabel("iterate support size")
axarr[1].set_xlabel("iterations")
plt.show(block=False)
```
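
As an extra sanity check (not part of the original script), the reference `x_star` can be validated against the Lasso KKT conditions with plain NumPy, independently of ModOpt. The sketch below reuses `X`, `y`, `lmbd` and `x_star` from the script above and should confirm the "fits to machine precision" comment, so the growing distance in the first panel is not an artifact of a loose reference solution.

```python
# KKT conditions for min_w 0.5 * ||y - X @ w||**2 + lmbd * ||w||_1:
# |X.T @ (y - X @ x_star)| <= lmbd everywhere, with equality
# lmbd * sign(x_star_j) on the support of x_star.
theta = X.T @ (y - X @ x_star)
support = x_star != 0

print("max |theta| off support:", np.abs(theta[~support]).max(), "vs lmbd:", lmbd)
print("max KKT violation on support:",
      np.abs(theta[support] - lmbd * np.sign(x_star[support])).max())
```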
ping @agramfort @tommoral