deas-mhumhna
What are people's thoughts on how constrained optimization should be handled? I ask because I'm writing an implementation of the parametric simplex and affine scaling methods for linear optimization, and...
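For readers unfamiliar with the second method, here is a hedged sketch of primal affine scaling (Dikin's method). It targets standard-form LPs, min c^T x s.t. Ax = b, x >= 0, and assumes a strictly positive feasible starting point is already known. Each iteration scales by D = diag(x), estimates duals w from the normal equations (A D^2 A^T) w = A D^2 c, and steps along -D^2 (c - A^T w), stopping short of the boundary by a factor gamma (0.9 here, an assumed value). It is a pure-Python illustration, not the poster's implementation; a real one would use a proper linear-algebra library.

```python
def solve(M, v):
    """Solve M y = v by Gaussian elimination with partial pivoting."""
    n = len(v)
    aug = [row[:] + [v[i]] for i, row in enumerate(M)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            f = aug[r][col] / aug[col][col]
            for k in range(col, n + 1):
                aug[r][k] -= f * aug[col][k]
    y = [0.0] * n
    for i in reversed(range(n)):
        s = aug[i][n] - sum(aug[i][k] * y[k] for k in range(i + 1, n))
        y[i] = s / aug[i][i]
    return y


def affine_scaling(A, c, x, gamma=0.9, tol=1e-9, max_iter=500):
    """Primal affine scaling for min c^T x s.t. A x = b, x >= 0.

    x must be strictly positive with A x = b (b is implied by x).
    Returns the final primal iterate and dual estimate w.
    """
    m, n = len(A), len(c)
    x = list(x)
    for _ in range(max_iter):
        d2 = [xi * xi for xi in x]  # diagonal of D^2
        # Normal equations (A D^2 A^T) w = A D^2 c give the dual estimate w.
        M = [[sum(A[i][k] * d2[k] * A[j][k] for k in range(n)) for j in range(m)]
             for i in range(m)]
        rhs = [sum(A[i][k] * d2[k] * c[k] for k in range(n)) for i in range(m)]
        w = solve(M, rhs)
        # Reduced costs r = c - A^T w; the duality gap c^T x - b^T w equals r^T x.
        r = [c[k] - sum(A[i][k] * w[i] for i in range(m)) for k in range(n)]
        gap = sum(rk * xk for rk, xk in zip(r, x))
        if min(r) >= -tol and gap < tol:
            break  # approximately primal-dual optimal
        dx = [-d2[k] * r[k] for k in range(n)]  # satisfies A dx = 0
        ratios = [-x[k] / dx[k] for k in range(n) if dx[k] < 0]
        if not ratios:
            raise ValueError("LP appears unbounded")
        alpha = gamma * min(ratios)  # stay strictly inside x > 0
        x = [x[k] + alpha * dx[k] for k in range(n)]
    return x, w
```

On the toy LP min -x1 - 2*x2 s.t. x1 + x2 + x3 = 1, x >= 0, started at (1/3, 1/3, 1/3), the iterates approach the vertex (0, 1, 0) and the dual estimate approaches -2.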
Any interest in refactoring the binomial and normal nodes to use `gonum/stat` interfaces, so that a wider variety of distributions, including user-defined ones, can be plugged in? Perhaps create a separate subpackage just for random numbers?
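The gist of the proposal is that a sampling node should depend on an interface rather than on concrete binomial/normal types. A minimal Python sketch of that idea follows; the `Distribution` protocol is loosely modelled on the `Rand`/`LogProb` methods of gonum's `stat/distuv` types, and `Normal`, `Exponential` and `random_node` are hypothetical names for illustration, not an existing API.

```python
import math
import random
from typing import List, Protocol


class Distribution(Protocol):
    """Hypothetical interface: anything with rand() and log_prob() qualifies."""

    def rand(self, rng: random.Random) -> float: ...

    def log_prob(self, x: float) -> float: ...


class Normal:
    def __init__(self, mu: float, sigma: float):
        self.mu, self.sigma = mu, sigma

    def rand(self, rng: random.Random) -> float:
        return rng.gauss(self.mu, self.sigma)

    def log_prob(self, x: float) -> float:
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma * math.sqrt(2 * math.pi))


class Exponential:
    """A 'user-defined' distribution: it only has to implement the two methods."""

    def __init__(self, rate: float):
        self.rate = rate

    def rand(self, rng: random.Random) -> float:
        return rng.expovariate(self.rate)

    def log_prob(self, x: float) -> float:
        return math.log(self.rate) - self.rate * x if x >= 0 else -math.inf


def random_node(dist: Distribution, size: int, seed: int = 0) -> List[float]:
    """Hypothetical graph node: fills a flat tensor from any Distribution."""
    rng = random.Random(seed)  # seeded source keeps samples reproducible
    return [dist.rand(rng) for _ in range(size)]
```

With this shape, adding a new distribution never touches the node code; the node only sees the two-method interface.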
Can you currently modify gradients wrt specific nodes? My particular application is performing variance reduction on stochastic gradients. For this, I need to set gradients wrt specific nodes, i.e. `GradWRT`...
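The poster's exact variance-reduction scheme isn't stated. As one common example, SVRG needs the gradient of the same sample at two parameter values (the current iterate and a snapshot), which is why per-node gradient control matters. The pure-Python sketch below runs SVRG on an assumed toy 1-D least-squares problem; `grad_i`, `full_grad` and `svrg` are hypothetical helpers, not Gorgonia API.

```python
import random


def grad_i(w, x, y):
    """Gradient of 0.5 * (w*x - y)**2 with respect to w, for one sample."""
    return (w * x - y) * x


def full_grad(w, data):
    """Average gradient over the whole dataset."""
    return sum(grad_i(w, x, y) for x, y in data) / len(data)


def svrg(data, w0=0.0, lr=0.1, epochs=20, inner=50, seed=0):
    rng = random.Random(seed)
    w = w0
    for _ in range(epochs):
        w_snap = w
        mu = full_grad(w_snap, data)  # exact gradient at the snapshot
        for _ in range(inner):
            x, y = rng.choice(data)
            # Control-variate estimator: unbiased for full_grad(w), and its
            # variance shrinks as w and w_snap approach the optimum.
            g = grad_i(w, x, y) - grad_i(w_snap, x, y) + mu
            w -= lr * g
    return w
```

On data generated as y = 3x the estimate converges to w = 3; in a graph framework, the two `grad_i` calls correspond to evaluating the same node's gradient at two parameter values.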
I'm trying to use `jaxopt.OSQP` as part of the projection step of another training algorithm. The QP is ```math \begin{equation} \begin{split} \text{min} & \quad ||\theta-\theta_0||^2 \\ \text{s.t.} & \lim_{x \to\infty}f(x,...
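The constraint in the excerpt is cut off, so as an assumption take a single linear constraint a . theta <= b. For that case the projection min ||theta - theta_0||^2 has a closed form, which can serve as a sanity check on the OSQP output before moving to the real constraint. `project_halfspace` is a hypothetical helper:

```python
def project_halfspace(theta0, a, b):
    """Euclidean projection of theta0 onto {theta : a . theta <= b}.

    Solves min ||theta - theta0||^2 s.t. a . theta <= b in closed form:
    if theta0 violates the constraint, move it along a onto the boundary.
    """
    viol = sum(ai * ti for ai, ti in zip(a, theta0)) - b
    if viol <= 0:
        return list(theta0)  # already feasible: the projection is theta0 itself
    scale = viol / sum(ai * ai for ai in a)
    return [ti - scale * ai for ti, ai in zip(theta0, a)]
```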
I just discovered the Diffrax package and it's great! However, I'm encountering an issue where the gradient evaluation is 40-80 times slower than the forward pass for my particular network...
```python
import jax.numpy as jnp
import tree_math as tm

def f(x, y):
    return x, y

x = y = tm.Vector(jnp.array(0.))
tm.unwrap(f, out_vectors=(True, False))(x, y)
# (tree_math.Vector(DeviceArray(0., dtype=float32, weak_type=True)),...
```
```python
import jaxopt

fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
matvec_A = lambda params_A, z: (z, )
solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun,...
```
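In this snippet, `fun` is 0.5 z.z + c.z - 1 with c = `params_obj[0]`, and `matvec_A` maps z to itself. If the constraint bounds are a box lower <= z <= upper (the bounds are cut off in the excerpt, so this is an assumption), the problem separates per coordinate and the minimiser is simply -c clipped into the box. A hypothetical pure-Python reference for checking the solver's output:

```python
def box_qp_reference(c, lower, upper):
    """Minimiser of 0.5*z.z + c.z - 1 subject to lower <= z <= upper.

    Each coordinate independently minimises 0.5*z_i**2 + c_i*z_i, whose
    unconstrained minimiser -c_i is then clipped into [lower_i, upper_i].
    The constant -1 shifts the optimal value but not the minimiser.
    """
    return [min(max(-ci, lo), hi) for ci, lo, hi in zip(c, lower, upper)]
```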