simsopt

Implementing a constraint on max current

Open abaillod opened this issue 1 year ago • 2 comments

Hi,

I tried implementing a constraint on the max current in a coil. The objective looks like this:

def current_penalty_pure(I, threshold):
    """
    Return the squared amount by which ``|I|`` exceeds ``threshold``.

    The value is ``(|I| - threshold)**2`` when ``|I| > threshold`` and ``0``
    otherwise, so the penalty is flat inside the allowed band and grows
    quadratically outside it. All operations are element-wise, so ``I`` may
    be a scalar or an array.
    """
    overshoot = abs(I) - threshold
    return jnp.maximum(overshoot, 0) ** 2

class CurrentPenalty(Optimizable):
    """
    A :obj:`CurrentPenalty` can be used to penalize
    large currents in coils.

    The objective is ``max(|I| - threshold, 0)**2``, where ``I`` is the first
    degree of freedom of ``current`` (``current.x[0]``). It is zero while the
    current magnitude is at or below ``threshold`` and grows quadratically
    beyond it.

    Args:
        current: the current optimizable whose value is penalized.
        threshold: current magnitude below which no penalty is applied.
    """

    def __init__(self, current, threshold=0):
        self.current = current
        self.threshold = threshold

        # ``self.threshold`` is read when the lambda is called, not captured
        # here, so reassigning it after construction changes the penalty.
        self.J_jax = lambda I: current_penalty_pure(I, self.threshold)
        self.this_grad = lambda I: grad(self.J_jax, argnums=0)(I)

        super().__init__(depends_on=[current])

    def J(self):
        """Return the penalty evaluated at the current's first DOF."""
        return self.J_jax(self.current.x[0])

    @derivative_dec
    def dJ(self):
        """
        Return the derivative of the penalty with respect to the current's DOFs.

        ``grad`` of a scalar function of the scalar ``x[0]`` is 0-d. The
        derivative machinery, however, selects entries with the boolean mask
        ``local_dofs_free_status``, which has one entry per DOF (shape
        ``(1,)`` here). Masking a 0-d array with it raises
        ``IndexError: boolean index did not match shape ...``, so the
        gradient is promoted to a length-1 array before it is passed to
        ``vjp``.
        """
        grad0 = self.this_grad(self.current.x[0])
        # NOTE(review): assumes ``current`` owns exactly one DOF (as
        # ``Current`` does); revisit if this is used with other current types.
        return self.current.vjp(grad0.reshape((1,)))

However, when running the following simple test,

# ``Current`` must be imported as well; without it this snippet only runs
# when an earlier session already imported it.
from simsopt.field import Current, CurrentPenalty

c = Current(1e5)
test = CurrentPenalty(c)
test.dJ()

I get the error

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[3], line 5
      3 c = Current(1e5)
      4 test = CurrentPenalty(c)
----> 5 test.dJ()

File ~/Github/simsopt/src/simsopt/_core/derivative.py:217, in derivative_dec.<locals>._derivative_dec(self, partials, *args, **kwargs)
    215     return func(self, *args, **kwargs)
    216 else:
--> 217     return func(self, *args, **kwargs)(self)

File ~/Github/simsopt/src/simsopt/_core/derivative.py:185, in Derivative.__call__(self, optim, as_derivative)
    183 local_derivs = np.zeros(k.local_dof_size)
    184 for opt in k.dofs.dep_opts():
--> 185     local_derivs += self.data[opt][opt.local_dofs_free_status]
    186     keys.append(opt)
    187 derivs.append(local_derivs)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/array.py:317, in ArrayImpl.__getitem__(self, idx)
    315   return lax_numpy._rewriting_take(self, idx)
    316 else:
--> 317   return lax_numpy._rewriting_take(self, idx)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4142, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4136     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4137         dtypes.issubdtype(aval.dtype, np.integer) and
   4138         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4139         isinstance(arr.shape[0], int)):
   4140       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4142 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4143 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4144                unique_indices, mode, fill_value)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4220, in _split_index_for_jit(idx, shape)
   4216   raise TypeError(f"JAX does not support string indexing; got {idx=}")
   4218 # Expand any (concrete) boolean indices. We can then use advanced integer
   4219 # indexing logic to handle them.
-> 4220 idx = _expand_bool_indices(idx, shape)
   4222 leaves, treedef = tree_flatten(idx)
   4223 dynamic = [None] * len(leaves)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4542, in _expand_bool_indices(idx, shape)
   4540     expected_shape = shape[start: start + _ndim(i)]
   4541     if i_shape != expected_shape:
-> 4542       raise IndexError("boolean index did not match shape of indexed array in index "
   4543                        f"{dim_number}: got {i_shape}, expected {expected_shape}")
   4544     out.extend(np.where(i))
   4545 else:

IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()

Does anyone have an idea how to fix it?

abaillod avatar Jun 05 '24 18:06 abaillod

Not sure, but one suggestion would be to implement this without JAX, since it's such a simple penalty.

andrewgiuliani avatar Jun 05 '24 19:06 andrewgiuliani

Looks like the derivative decorator trips over the fact that the optimizable has only one DOF. The JAX internals don't make it very clear, but check whether current.dofs.dep_opts()[0].dofs_free_status is a boolean instead of a length-one array of booleans.

smiet avatar Jun 05 '24 20:06 smiet