Question about derivatives
Hi!
When writing an objective that depends on an Optimizable that has no local dofs, I run into an error (see below) that is quite obscure to me. I would appreciate some help!
Here is a simple example of a class that reproduces the problem:
import jax.numpy as jnp
from jax import vjp, grad
from simsopt._core.optimizable import Optimizable
from simsopt._core.derivative import Derivative, derivative_dec

def pure_objective(gamma, current):
    # Mean distance of the curve points from the origin, scaled by the current
    r = jnp.sqrt(gamma[:, 0]**2 + gamma[:, 1]**2 + gamma[:, 2]**2)
    return jnp.mean(r) * current

class TestObjective(Optimizable):
    def __init__(self, coil):
        self.coil = coil
        self.J_jax = lambda gamma, current: pure_objective(gamma, current)
        self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
        self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0]
        super().__init__(depends_on=[coil])

    def J(self):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()
        return self.J_jax(gamma, current)

    def vjp(self, v):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()
        return Derivative(
            {
                self.coil.curve: self.dobj_by_dgamma_vjp(gamma, current, v),
                self.coil.current: self.dobj_by_dcurrent_vjp(gamma, current, v)
            }
        )
def squared_objective(obj):
    return obj**2

class ObjectiveSquared(Optimizable):
    def __init__(self, obj):
        self.obj = obj
        self.J_jax = lambda x: squared_objective(x)
        self.thisgrad = lambda x: grad(self.J_jax)(x)
        super().__init__(depends_on=[obj])

    def J(self):
        return self.J_jax(self.obj.J())

    @derivative_dec
    def dJ(self):
        x = self.obj.J()
        # Chain rule: scale the inner objective's vjp by d(squared_objective)/dx
        grad_x = self.thisgrad(x)
        return self.obj.vjp(grad_x)
Then, for a given coil, if we do
tt = TestObjective(coil)
sq = ObjectiveSquared(tt)
sq.dJ()
we get:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[43], line 1
----> 1 sq.dJ()
File ~/Github/simsopt/src/simsopt/_core/derivative.py:217, in derivative_dec.<locals>._derivative_dec(self, partials, *args, **kwargs)
215 return func(self, *args, **kwargs)
216 else:
--> 217 return func(self, *args, **kwargs)(self)
File ~/Github/simsopt/src/simsopt/_core/derivative.py:185, in Derivative.__call__(self, optim, as_derivative)
183 local_derivs = np.zeros(k.local_dof_size)
184 for opt in k.dofs.dep_opts():
--> 185 local_derivs += self.data[opt][opt.local_dofs_free_status]
186 keys.append(opt)
187 derivs.append(local_derivs)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/array.py:317, in ArrayImpl.__getitem__(self, idx)
315 return lax_numpy._rewriting_take(self, idx)
316 else:
--> 317 return lax_numpy._rewriting_take(self, idx)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4142, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
4136 if (isinstance(aval, core.DShapedArray) and aval.shape == () and
4137 dtypes.issubdtype(aval.dtype, np.integer) and
4138 not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
4139 isinstance(arr.shape[0], int)):
4140 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4142 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
4143 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
4144 unique_indices, mode, fill_value)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4220, in _split_index_for_jit(idx, shape)
4216 raise TypeError(f"JAX does not support string indexing; got {idx=}")
4218 # Expand any (concrete) boolean indices. We can then use advanced integer
4219 # indexing logic to handle them.
-> 4220 idx = _expand_bool_indices(idx, shape)
4222 leaves, treedef = tree_flatten(idx)
4223 dynamic = [None] * len(leaves)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4542, in _expand_bool_indices(idx, shape)
4540 expected_shape = shape[start: start + _ndim(i)]
4541 if i_shape != expected_shape:
-> 4542 raise IndexError("boolean index did not match shape of indexed array in index "
4543 f"{dim_number}: got {i_shape}, expected {expected_shape}")
4544 out.extend(np.where(i))
4545 else:
IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()
Investigating further, I found the following: when the @derivative_dec decorator is removed, sq.dJ() does not raise an error. The error only occurs when gathering all the partial derivatives together, i.e. when calling sq.dJ()(sq) (or, equivalently, when the derivative_dec decorator is included).
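If I read the traceback correctly, the failing line indexes each partial derivative with the boolean mask of free dofs (self.data[opt][opt.local_dofs_free_status]), and the partial w.r.t. the current is a 0-d JAX array. The same IndexError can be reproduced standalone (the variable names here are just for illustration):

import numpy as np
import jax.numpy as jnp

mask = np.array([True])     # a Current has a single dof, so its free-dof mask has shape (1,)
partial = jnp.asarray(2.5)  # a vjp w.r.t. a scalar returns a 0-d array, shape ()

partial[mask]               # IndexError: boolean index did not match shape of
                            # indexed array in index 0: got (1,), expected ()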
@abaillod
I'll take a look at this in a couple of days.
One quick question: why define self.J_jax instead of directly using pure_objective?
@mbkumar Thank you for your answer.
Here I defined self.J_jax to mimic what I do in another, more complex class I am working on. In general, objectives can have additional input parameters, and defining self.J_jax as a lambda makes things more readable, in my opinion.
For example, I could define
def pure_objective(gamma, current, k):
    r = jnp.sqrt(gamma[:, 0]**2 + gamma[:, 1]**2 + gamma[:, 2]**2)
    return jnp.mean(r) * current**k

class TestObjective(Optimizable):
    def __init__(self, coil, k):
        self.coil = coil
        self.k = k
        self.J_jax = lambda gamma, current: pure_objective(gamma, current, self.k)
        ...
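(As a side note, the same binding can be done without a lambda via functools.partial; a small self-contained sketch, with an arbitrary k=2 for illustration:)

from functools import partial
import jax.numpy as jnp

def pure_objective(gamma, current, k):
    r = jnp.sqrt(gamma[:, 0]**2 + gamma[:, 1]**2 + gamma[:, 2]**2)
    return jnp.mean(r) * current**k

# Fix k once; the result has the same (gamma, current) signature as the lambda
J_jax = partial(pure_objective, k=2)
print(J_jax(jnp.ones((4, 3)), 10.0))  # sqrt(3) * 10**2, roughly 173.2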
@mbkumar any news about this?
What is the dimension of the array returned by dJ()?
This is what I get:

[screenshot of the dJ() output]

The derivative w.r.t. the current does not have the right shape - do I understand this correctly?
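One way to inspect the individual partials without triggering the failing gather is via the partials argument of derivative_dec (visible in the traceback above), since the Derivative object stores its per-Optimizable entries in a data dict (also visible in the traceback):

d = sq.dJ(partials=True)  # returns the raw Derivative, skipping the __call__ that raises
for opt, val in d.data.items():
    print(opt, jnp.shape(val))
# the entry for the Current comes out as a 0-d array, shape (), instead of shape (1,)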
To give further context to this issue: it is related to the branch coil_forces, in which @phuslage and I are working on a way to optimize for critical current.
You can have a look at these lines, where we attempt to take the derivative of a new objective called CriticalCurrent. It works for every degree of freedom except when taking the derivative w.r.t. the coil current.
In any case, I would appreciate it if someone could give me an example of how to implement these derivatives correctly, or help me debug this.
Hi @abaillod, I am happy to pair-debug this with you; send me a calendar invite.
Thank you. I just sent you an email to find a time.
@abaillod,
Sorry, I went radio silent after our last conversation. I have had a lot of work these last couple of weeks and couldn't respond. Please add me as an optional attendee.
Thanks to Andrew and Bharat, we found out how to fix it. In a nutshell, I have to pass an array to simsopt.field.coil.Current.vjp, not a scalar. The class TestObjective should then be:
class TestObjective(Optimizable):
    def __init__(self, coil):
        self.coil = coil
        self.J_jax = lambda gamma, current: pure_objective(gamma, current)
        self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
        self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0]
        super().__init__(depends_on=[coil])

    def J(self):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()
        return self.J_jax(gamma, current)

    def vjp(self, v):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()
        grad0 = self.dobj_by_dgamma_vjp(gamma, current, v)
        # Current.vjp expects a shape-(1,) array, not a 0-d scalar:
        grad1 = jnp.array([self.dobj_by_dcurrent_vjp(gamma, current, v)])
        return self.coil.curve.dgamma_by_dcoeff_vjp(grad0) + self.coil.current.vjp(grad1)
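For completeness, here is a minimal check of the fix; the coil construction below is just an illustrative guess (any simsopt Coil should do):

import numpy as np
from simsopt.geo import CurveXYZFourier
from simsopt.field import Current, Coil

curve = CurveXYZFourier(32, 1)  # 32 quadrature points, Fourier order 1
curve.x = 0.1 * np.random.standard_normal(curve.dof_size)
coil = Coil(curve, Current(1e4))

tt = TestObjective(coil)
sq = ObjectiveSquared(tt)
print(sq.dJ())  # now returns one gradient entry per free dof, current included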