jaxopt icon indicating copy to clipboard operation
jaxopt copied to clipboard

Constraint violation causes L-BGFS-B to fail

Open gulls-on-parade opened this issue 1 year ago • 1 comments

I believe the line search internally used by jaxopt.LBFGSB is not respecting the bounds that are passed here, causing the objective function to generate NaNs and the overall optimization problem to fail. I am unsure if this a bug, or if I am doing something wrong. Any guidance is much appreciated.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count={}'.format(os.cpu_count())
import jax as jx
from jax import jit
import jax.numpy as jnp
import jaxopt

#%% Helper functions

@jit
def normal_vol(strike, atmf, t, alpha, beta, rho, nu):
    eps = 1e-07  # Numerical tolerance
    f_av = jnp.sqrt(atmf * strike)

    fmkr = jnp.select([(jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) > eps), 
                       (jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) <= eps), 
                       jnp.abs(atmf - strike) <= eps],
                      [(1 - beta) * (atmf - strike) / (atmf**(1 - beta) - strike**(1 - beta)),
                       (atmf - strike) / jnp.log(atmf / strike),
                       strike**beta],
                      jnp.nan)
    
    zeta = nu * (atmf - strike) / (alpha * f_av**beta)
    
    zxz = jnp.select([jnp.abs(zeta) > eps, 
                      jnp.abs(zeta) <= eps],
                     [zeta / jnp.log(jnp.abs(((1 - 2 * rho * zeta + zeta**2)**.5 + zeta - rho) / (1 - rho))),
                      1.],
                     jnp.nan)
    
    a = - beta * (2 - beta) * alpha**2 / (24 * f_av**(2 - 2 * beta))
    b = rho * alpha * nu * beta / (4 * f_av**(1 - beta))
    c = (2 - 3 * rho**2) * nu**2 / 24

    vol = alpha * fmkr * zxz * (1 + (a + b + c) * t)

    return vol


@jit
def _obj(params, args):
    """Objective function to minimize the squared error between implied and model vols."""
    expiry, tail, strikes, vols, atmf, beta = args
    alpha, rho, nu = params
    vol_fitted = jx.vmap(normal_vol, (0, None, None, None, None, None, None))(strikes, atmf, expiry, alpha, beta, rho, nu)
    error = (vol_fitted - vols) * 1e4
    return jnp.sum(error**2)

#%% Example problem

data = [(0.09041095890410959,
  0.2465753424657534,
  jnp.array([0.0824076, 0.0849076, 0.0874076, 0.0899076, 0.0924076, 0.0949076,
         0.0974076, 0.0999076, 0.1024076, 0.1049076, 0.1074076, 0.1099076,
         0.1124076, 0.1149076, 0.1174076, 0.1199076, 0.1224076, 0.1249076,
         0.1274076, 0.1299076, 0.1324076, 0.1349076, 0.1374076, 0.1399076,
         0.1424076]),
  jnp.array([0.02100495, 0.02000676, 0.01897691, 0.01791351, 0.016814,
         0.01567488, 0.01449142, 0.0132571 , 0.01196264, 0.0105943 ,
         0.00913049, 0.00753422, 0.00573621, 0.00368666, 0.00298916,
         0.00351651, 0.00417858, 0.00485768, 0.00553383, 0.00620241,
         0.00686251, 0.00751431, 0.00815832, 0.00879512, 0.00942527]),
  0.11240760359238675,
  0.25),
 (0.09041095890410959,
  1.0027397260273974,
  jnp.array([0.07611851, 0.07861851, 0.08111851, 0.08361851, 0.08611851,
         0.08861851, 0.09111851, 0.09361851, 0.09611851, 0.09861851,
         0.10111851, 0.10361851, 0.10611851, 0.10861851, 0.11111851,
         0.11361851, 0.11611851, 0.11861851, 0.12111851, 0.12361851,
         0.12611851, 0.12861851, 0.13111851, 0.13361851, 0.13611851]),
  jnp.array([0.02571163, 0.02466922, 0.02359377, 0.02248411, 0.02133859,
         0.02015503, 0.01893064, 0.01766194, 0.01634481, 0.01497479,
         0.01354828, 0.01206712, 0.01055505, 0.00911653, 0.00807032,
         0.00778549, 0.00810574, 0.00870589, 0.0094221 , 0.01018791,
         0.01097495, 0.01177004, 0.01256661, 0.01336128, 0.01415225]),
  0.10611850901102435,
  0.25),
 (0.09041095890410959,
  2.0027397260273974,
  jnp.array([0.06970405, 0.07220405, 0.07470405, 0.07720405, 0.07970405,
         0.08220405, 0.08470405, 0.08720405, 0.08970405, 0.09220405,
         0.09470405, 0.09720405, 0.09970405, 0.10220405, 0.10470405,
         0.10720405, 0.10970405, 0.11220405, 0.11470405, 0.11720405,
         0.11970405, 0.12220405, 0.12470405, 0.12720405, 0.12970405]),
  jnp.array([0.02641612, 0.02545857, 0.02447167, 0.02345581, 0.02241125,
         0.02133829, 0.02023758, 0.01911054, 0.01796036, 0.01679381,
         0.01562486, 0.01448212, 0.01342213, 0.01254578, 0.01198868,
         0.01184018, 0.01206377, 0.01254688, 0.01318733, 0.01391874,
         0.01470226, 0.01551549, 0.0163453 , 0.01718381, 0.01802615]),
  0.09970405414511939,
  0.25)]


x0 = jnp.array([0.01, 0.00, 0.10])
bounds = (jnp.array([0.0001, -0.9999, 0.0001]), jnp.array([999, 0.9999, 999]))
args = data[0]

# This fails, as the objective function is producing nans when the step size immediately violates bounds as part of the implicit differentiation
solver = jaxopt.LBFGSB(fun=_obj)
results = solver.run(x0, bounds=bounds, args=args)


# However the objective function evaluates properly at x0
_obj(x0, args)

gulls-on-parade avatar Apr 17 '24 16:04 gulls-on-parade

I might be having the same issue too, please let me know if you learn something new! For now I am clipping to the bounds myself in my loss function.

charles-zhng avatar Apr 19 '24 16:04 charles-zhng