
Question about derivatives

Open abaillod opened this issue 1 year ago • 8 comments

Hi!

When writing an objective that depends on an Optimizable that has no local dofs, I run into an error (see below) that is quite obscure to me. I would appreciate some help!

Here is a simple example of a class that generates the problem:

# Imports assumed for this snippet
import jax.numpy as jnp
from jax import vjp, grad
from simsopt._core.optimizable import Optimizable
from simsopt._core.derivative import Derivative, derivative_dec

def pure_objective(gamma, current):
    r = jnp.sqrt(gamma[:,0]**2 + gamma[:,1]**2 + gamma[:,2]**2)
    return jnp.mean(r) * current

class TestObjective(Optimizable):
    def __init__(self, coil):
        self.coil = coil
        self.J_jax = lambda gamma, current: pure_objective(gamma, current)
        self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
        self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0] 
        
        super().__init__(depends_on=[coil])


    def J(self):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        return self.J_jax(gamma, current)


    def vjp(self, v):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        return Derivative(
            {
                self.coil.curve: self.dobj_by_dgamma_vjp(gamma, current, v),
                self.coil.current: self.dobj_by_dcurrent_vjp(gamma, current, v)
            }
        )


def squared_objective(obj):
    return obj**2

class ObjectiveSquared(Optimizable):
    def __init__(self, obj):
        self.obj = obj

        self.J_jax = lambda x: squared_objective(x)
        self.thisgrad = lambda x: grad(self.J_jax)(x)
        
        super().__init__(depends_on=[obj])

    def J(self):
        return self.J_jax(self.obj.J())

    @derivative_dec
    def dJ(self):
        x = self.obj.J()
        grad = self.thisgrad( x )

        return self.obj.vjp( grad )

Then, for a given coil, if we do

tt = TestObjective(coil)
sq = ObjectiveSquared(tt)
sq.dJ()

We get

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[43], line 1
----> 1 sq.dJ()

File [~/Github/simsopt/src/simsopt/_core/derivative.py:217](http://localhost:8889/lab/tree/Projects/CriticalCurrent/~/Github/simsopt/src/simsopt/_core/derivative.py#line=216), in derivative_dec.<locals>._derivative_dec(self, partials, *args, **kwargs)
    215     return func(self, *args, **kwargs)
    216 else:
--> 217     return func(self, *args, **kwargs)(self)

File [~/Github/simsopt/src/simsopt/_core/derivative.py:185](http://localhost:8889/lab/tree/Projects/CriticalCurrent/~/Github/simsopt/src/simsopt/_core/derivative.py#line=184), in Derivative.__call__(self, optim, as_derivative)
    183 local_derivs = np.zeros(k.local_dof_size)
    184 for opt in k.dofs.dep_opts():
--> 185     local_derivs += self.data[opt][opt.local_dofs_free_status]
    186     keys.append(opt)
    187 derivs.append(local_derivs)

File [/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/array.py:317](http://localhost:8889/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/array.py#line=316), in ArrayImpl.__getitem__(self, idx)
    315   return lax_numpy._rewriting_take(self, idx)
    316 else:
--> 317   return lax_numpy._rewriting_take(self, idx)

File [/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4142](http://localhost:8889/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py#line=4141), in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4136     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4137         dtypes.issubdtype(aval.dtype, np.integer) and
   4138         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4139         isinstance(arr.shape[0], int)):
   4140       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4142 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4143 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4144                unique_indices, mode, fill_value)

File [/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4220](http://localhost:8889/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py#line=4219), in _split_index_for_jit(idx, shape)
   4216   raise TypeError(f"JAX does not support string indexing; got {idx=}")
   4218 # Expand any (concrete) boolean indices. We can then use advanced integer
   4219 # indexing logic to handle them.
-> 4220 idx = _expand_bool_indices(idx, shape)
   4222 leaves, treedef = tree_flatten(idx)
   4223 dynamic = [None] * len(leaves)

File [/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4542](http://localhost:8889/opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py#line=4541), in _expand_bool_indices(idx, shape)
   4540     expected_shape = shape[start: start + _ndim(i)]
   4541     if i_shape != expected_shape:
-> 4542       raise IndexError("boolean index did not match shape of indexed array in index "
   4543                        f"{dim_number}: got {i_shape}, expected {expected_shape}")
   4544     out.extend(np.where(i))
   4545 else:

IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()
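For reference, the failing index operation can be reproduced outside simsopt: the traceback boils down to indexing a 0-d array with a length-1 boolean mask. A minimal sketch with hypothetical shapes matching the traceback, using NumPy for illustration (JAX arrays behave the same way here):

```python
import numpy as np

# A Current has a single dof, so its free-dof mask has shape (1,),
# while the vjp above returns a 0-d (scalar) array for the current derivative.
free_status = np.array([True])   # stand-in for local_dofs_free_status
scalar_grad = np.array(2.0)      # 0-d result of dobj_by_dcurrent_vjp

try:
    scalar_grad[free_status]     # what Derivative.__call__ attempts
    failed = False
except IndexError:
    failed = True                # boolean mask does not match the 0-d shape

# Promoting the scalar to a length-1 array makes the same indexing work:
array_grad = np.atleast_1d(scalar_grad)
result = array_grad[free_status]  # -> array([2.])
```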

abaillod avatar Aug 01 '24 14:08 abaillod

Investigating more, I found the following: when removing the @derivative_dec decorator, sq.dJ() does not raise an error. The error only occurs when gathering all the partial derivatives together, i.e. when calling sq.dJ()(sq) (or, equivalently, when the @derivative_dec decorator is included).

abaillod avatar Aug 01 '24 14:08 abaillod

@abaillod I'll take a look at this in a couple of days. One quick question: why define self.J_jax instead of directly using pure_objective?

mbkumar avatar Aug 06 '24 20:08 mbkumar

@mbkumar Thank you for your answer.

Here I defined self.J_jax to mimic what I do in another, more complex class I am working on. In general, objectives can have additional input parameters, and defining self.J_jax as a lambda function makes things more readable, in my opinion.

For example, I could define

def pure_objective(gamma, current, k):
    r = jnp.sqrt(gamma[:,0]**2+gamma[:,1]**2+gamma[:,2]**2)
    return jnp.mean(r) * current**k

class TestObjective(Optimizable):
    def __init__(self, coil, k):
        self.coil = coil
        self.k = k
        self.J_jax = lambda gamma, current: pure_objective(gamma, current, self.k)
        ...

abaillod avatar Aug 08 '24 15:08 abaillod

@mbkumar any news about this?

abaillod avatar Aug 12 '24 14:08 abaillod

What is the dimension of the array returned by dJ()?

andrewgiuliani avatar Aug 13 '24 12:08 andrewgiuliani

This is what I get

[screenshot of the dJ() output]

The derivative w.r.t. the current does not have the right shape; do I understand this correctly?

abaillod avatar Aug 13 '24 17:08 abaillod

To give further context to this issue: it is related to the coil_forces branch, in which @phuslage and I are working on a way to optimize for critical current.

You can have a look at these lines, where we attempt to take the derivative of a new objective called CriticalCurrent. It works for any degree of freedom, except when taking the derivative w.r.t. the coil current.

Anyway, I would appreciate it if someone could give me an example of how to implement these derivatives correctly, or help me debug this.

abaillod avatar Aug 14 '24 19:08 abaillod

Hi @abaillod, I am happy to dual debug with you, send me a calendar invite

andrewgiuliani avatar Aug 14 '24 22:08 andrewgiuliani

Thank you. I just sent you an email to find a time.

abaillod avatar Aug 19 '24 13:08 abaillod

@abaillod,

Sorry, I went radio silent after the last conversation. I have had a lot of work these last couple of weeks and couldn't respond. Please add me as an optional attendee.

Bharat Medasani


mbkumar avatar Aug 19 '24 14:08 mbkumar

Thanks to Andrew and Bharat, we figured out how to fix it. In a nutshell, I have to pass an array to simsopt.field.coil.Current.vjp, not a scalar. The class TestObjective should then be

class TestObjective(Optimizable):
    def __init__(self, coil):
        self.coil = coil
        self.J_jax = lambda gamma, current: pure_objective(gamma, current)
        self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
        self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0] 
        
        super().__init__(depends_on=[coil])


    def J(self):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        return self.J_jax(gamma, current)


    def vjp(self, v):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        grad0 = self.dobj_by_dgamma_vjp(gamma, current, v)
        grad1 = jnp.array([self.dobj_by_dcurrent_vjp(gamma, current, v)])

        return self.coil.curve.dgamma_by_dcoeff_vjp(grad0) + self.coil.current.vjp(grad1)

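More generally, the same fix applies to any dof that is a scalar: promote the 0-d VJP output to a 1-d array before passing it on. A hypothetical convenience (as_vjp_array is not part of simsopt, just a sketch of the pattern; np.atleast_1d works equally well with jnp arrays):

```python
import numpy as np

def as_vjp_array(x):
    """Promote a scalar VJP result to a 1-d array.

    Current.vjp expects a length-1 array for the single current dof;
    this helper is a hypothetical convenience, not part of simsopt.
    """
    return np.atleast_1d(np.asarray(x))

# A 0-d scalar and a length-1 array both come out with shape (1,):
a = as_vjp_array(2.0)             # -> array([2.])
b = as_vjp_array(np.array([2.0])) # unchanged, still shape (1,)
```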
abaillod avatar Aug 20 '24 18:08 abaillod