Constraint violation causes L-BFGS-B to fail
I believe the line search used internally by jaxopt.LBFGSB is not respecting the bounds passed to the solver, which causes the objective function to generate NaNs and the overall optimization to fail. I am unsure if this is a bug or if I am doing something wrong. Any guidance is much appreciated.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count={}'.format(os.cpu_count())
import jax as jx
from jax import jit
import jax.numpy as jnp
import jaxopt
#%% Helper functions
@jit
def normal_vol(strike, atmf, t, alpha, beta, rho, nu):
    """Normal (Bachelier) SABR volatility approximation."""
    eps = 1e-07  # Numerical tolerance
    f_av = jnp.sqrt(atmf * strike)
    fmkr = jnp.select([(jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) > eps),
                       (jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) <= eps),
                       jnp.abs(atmf - strike) <= eps],
                      [(1 - beta) * (atmf - strike) / (atmf**(1 - beta) - strike**(1 - beta)),
                       (atmf - strike) / jnp.log(atmf / strike),
                       strike**beta],
                      jnp.nan)
    zeta = nu * (atmf - strike) / (alpha * f_av**beta)
    zxz = jnp.select([jnp.abs(zeta) > eps,
                      jnp.abs(zeta) <= eps],
                     [zeta / jnp.log(jnp.abs(((1 - 2 * rho * zeta + zeta**2)**.5 + zeta - rho) / (1 - rho))),
                      1.],
                     jnp.nan)
    a = -beta * (2 - beta) * alpha**2 / (24 * f_av**(2 - 2 * beta))
    b = rho * alpha * nu * beta / (4 * f_av**(1 - beta))
    c = (2 - 3 * rho**2) * nu**2 / 24
    vol = alpha * fmkr * zxz * (1 + (a + b + c) * t)
    return vol
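# Aside: jnp.select evaluates every branch before masking, so a branch like
# (atmf - strike) / jnp.log(atmf / strike) divides 0/0 when atmf == strike.
# The NaN is masked in the forward pass but can still leak through jax.grad.
# A minimal sketch of the usual "double-where" guard (a hypothetical helper,
# not used in the repro below): substitute a safe denominator first, then mask.
def _safe_div(num, den, eps=1e-07):
    safe_den = jnp.where(jnp.abs(den) > eps, den, 1.0)
    return jnp.where(jnp.abs(den) > eps, num / safe_den, 0.0)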
@jit
def _obj(params, args):
    """Objective function to minimize the squared error between implied and model vols."""
    expiry, tail, strikes, vols, atmf, beta = args
    alpha, rho, nu = params
    # Vectorize normal_vol over strikes only; all other arguments are broadcast.
    vol_fitted = jx.vmap(normal_vol, (0, None, None, None, None, None, None))(strikes, atmf, expiry, alpha, beta, rho, nu)
    error = (vol_fitted - vols) * 1e4  # Error in basis points
    return jnp.sum(error**2)
#%% Example problem
data = [(0.09041095890410959,
0.2465753424657534,
jnp.array([0.0824076, 0.0849076, 0.0874076, 0.0899076, 0.0924076, 0.0949076,
0.0974076, 0.0999076, 0.1024076, 0.1049076, 0.1074076, 0.1099076,
0.1124076, 0.1149076, 0.1174076, 0.1199076, 0.1224076, 0.1249076,
0.1274076, 0.1299076, 0.1324076, 0.1349076, 0.1374076, 0.1399076,
0.1424076]),
jnp.array([0.02100495, 0.02000676, 0.01897691, 0.01791351, 0.016814,
0.01567488, 0.01449142, 0.0132571 , 0.01196264, 0.0105943 ,
0.00913049, 0.00753422, 0.00573621, 0.00368666, 0.00298916,
0.00351651, 0.00417858, 0.00485768, 0.00553383, 0.00620241,
0.00686251, 0.00751431, 0.00815832, 0.00879512, 0.00942527]),
0.11240760359238675,
0.25),
(0.09041095890410959,
1.0027397260273974,
jnp.array([0.07611851, 0.07861851, 0.08111851, 0.08361851, 0.08611851,
0.08861851, 0.09111851, 0.09361851, 0.09611851, 0.09861851,
0.10111851, 0.10361851, 0.10611851, 0.10861851, 0.11111851,
0.11361851, 0.11611851, 0.11861851, 0.12111851, 0.12361851,
0.12611851, 0.12861851, 0.13111851, 0.13361851, 0.13611851]),
jnp.array([0.02571163, 0.02466922, 0.02359377, 0.02248411, 0.02133859,
0.02015503, 0.01893064, 0.01766194, 0.01634481, 0.01497479,
0.01354828, 0.01206712, 0.01055505, 0.00911653, 0.00807032,
0.00778549, 0.00810574, 0.00870589, 0.0094221 , 0.01018791,
0.01097495, 0.01177004, 0.01256661, 0.01336128, 0.01415225]),
0.10611850901102435,
0.25),
(0.09041095890410959,
2.0027397260273974,
jnp.array([0.06970405, 0.07220405, 0.07470405, 0.07720405, 0.07970405,
0.08220405, 0.08470405, 0.08720405, 0.08970405, 0.09220405,
0.09470405, 0.09720405, 0.09970405, 0.10220405, 0.10470405,
0.10720405, 0.10970405, 0.11220405, 0.11470405, 0.11720405,
0.11970405, 0.12220405, 0.12470405, 0.12720405, 0.12970405]),
jnp.array([0.02641612, 0.02545857, 0.02447167, 0.02345581, 0.02241125,
0.02133829, 0.02023758, 0.01911054, 0.01796036, 0.01679381,
0.01562486, 0.01448212, 0.01342213, 0.01254578, 0.01198868,
0.01184018, 0.01206377, 0.01254688, 0.01318733, 0.01391874,
0.01470226, 0.01551549, 0.0163453 , 0.01718381, 0.01802615]),
0.09970405414511939,
0.25)]
x0 = jnp.array([0.01, 0.00, 0.10])
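# LBFGSB expects bounds as a (lower, upper) pair matching the parameter structure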
bounds = (jnp.array([0.0001, -0.9999, 0.0001]), jnp.array([999, 0.9999, 999]))
args = data[0]
# This fails: the line search immediately steps to parameter values outside the bounds and the objective returns NaNs
solver = jaxopt.LBFGSB(fun=_obj)
results = solver.run(x0, bounds=bounds, args=args)
# However the objective function evaluates properly at x0
_obj(x0, args)
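One way to confirm where the NaN first appears is JAX's NaN checker; a minimal debugging sketch:
import jax
jax.config.update("jax_debug_nans", True)  # raise an error at the first op that produces a NaN
results = solver.run(x0, bounds=bounds, args=args)  # now fails loudly at the offending op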
I might be having the same issue too; please let me know if you learn anything new! For now I am clipping to the bounds myself in my loss function.
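Roughly like this, a minimal sketch assuming the _obj, x0, and bounds defined above (the _obj_clipped wrapper is hypothetical, not part of jaxopt):
@jit
def _obj_clipped(params, args):
    # Clip params into the box so stray line-search steps cannot leave the feasible region.
    # Note: this zeroes the gradient for clipped coordinates, so treat it as a workaround.
    params = jnp.clip(params, bounds[0], bounds[1])
    return _obj(params, args)

solver = jaxopt.LBFGSB(fun=_obj_clipped)
results = solver.run(x0, bounds=bounds, args=args)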