simsopt
simsopt copied to clipboard
Implementing a constraint on max current
Hi,
I tried implementing a constraint on the max current in a coil. The objective looks like this:
def current_penalty_pure(I, threshold):
return jnp.maximum(abs(I) - threshold, 0)**2
class CurrentPenalty(Optimizable):
"""
A :obj:`CurrentPenalty` can be used to penalize
large currents in coils.
"""
def __init__(self, current, threshold=0):
self.current = current
self.threshold = threshold
self.J_jax = lambda I: current_penalty_pure(I, self.threshold)
self.this_grad = lambda I: grad(self.J_jax, argnums=0)(I)
super().__init__(depends_on=[current])
def J(self):
return self.J_jax(self.current.x[0])
@derivative_dec
def dJ(self):
grad0 = self.this_grad(self.current.x[0])
return self.current.vjp(grad0)
However, when running the following simple test,
from simsopt.field import CurrentPenalty
c = Current(1e5)
test = CurrentPenalty(c)
test.dJ()
I get the error
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[3], line 5
3 c = Current(1e5)
4 test = CurrentPenalty(c)
----> 5 test.dJ()
File ~/Github/simsopt/src/simsopt/_core/derivative.py:217, in derivative_dec.<locals>._derivative_dec(self, partials, *args, **kwargs)
215 return func(self, *args, **kwargs)
216 else:
--> 217 return func(self, *args, **kwargs)(self)
File ~/Github/simsopt/src/simsopt/_core/derivative.py:185, in Derivative.__call__(self, optim, as_derivative)
183 local_derivs = np.zeros(k.local_dof_size)
184 for opt in k.dofs.dep_opts():
--> 185 local_derivs += self.data[opt][opt.local_dofs_free_status]
186 keys.append(opt)
187 derivs.append(local_derivs)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/array.py:317, in ArrayImpl.__getitem__(self, idx)
315 return lax_numpy._rewriting_take(self, idx)
316 else:
--> 317 return lax_numpy._rewriting_take(self, idx)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4142, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
4136 if (isinstance(aval, core.DShapedArray) and aval.shape == () and
4137 dtypes.issubdtype(aval.dtype, np.integer) and
4138 not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
4139 isinstance(arr.shape[0], int)):
4140 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4142 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
4143 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
4144 unique_indices, mode, fill_value)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4220, in _split_index_for_jit(idx, shape)
4216 raise TypeError(f"JAX does not support string indexing; got {idx=}")
4218 # Expand any (concrete) boolean indices. We can then use advanced integer
4219 # indexing logic to handle them.
-> 4220 idx = _expand_bool_indices(idx, shape)
4222 leaves, treedef = tree_flatten(idx)
4223 dynamic = [None] * len(leaves)
File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4542, in _expand_bool_indices(idx, shape)
4540 expected_shape = shape[start: start + _ndim(i)]
4541 if i_shape != expected_shape:
-> 4542 raise IndexError("boolean index did not match shape of indexed array in index "
4543 f"{dim_number}: got {i_shape}, expected {expected_shape}")
4544 out.extend(np.where(i))
4545 else:
IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()
Does anyone have an idea how to fix it?
not sure, but one suggestion would be to implement this without Jax since it's such a simple penalty
Looks like the derivative decorator trips over the fact that the optimizable only has one dof. The Jax internals do not make it very clear, but check if current.dofs.dep_opts()[0].dofs_free_status is a boolean instead of a length one array thereof