Vector-valued objective functions
I am using Tidy3D to perform multi-objective inverse design. To speed things up, I'd like to run batches with web.run_async and return the different objectives as a vector value (I need the individual values to define multiple constraints in my optimisation problem).
To do that, I can't use jax.value_and_grad (which is limited to scalar functions). Instead, I need to use jax.jacrev or jax.jacfwd.
But when I do that, I run into a ConcretizationTypeError if I'm using a FieldMonitor, or a TracerArrayConversionError if I'm using a ModeMonitor.
A simple way to reproduce the problem is to extend the tutorial at https://www.flexcompute.com/tidy3d/examples/notebooks/AdjointPlugin1Intro/ with:
jac = jax.jacrev(power, argnums=(0,1,2))
d_power = jac(center0, size0, eps0)
This is not a vector-valued objective function, but the problem is the same:
Traceback (most recent call last):
File ~/micromamba/envs/tidy3d/lib/python3.10/runpy.py:196 in _run_module_as_main
return _run_code(code, main_globals, None,
File ~/micromamba/envs/tidy3d/lib/python3.10/runpy.py:86 in _run_code
exec(code, run_globals)
File ~/.local/lib/python3.10/site-packages/spyder_kernels/console/__main__.py:24
start.main()
File ~/.local/lib/python3.10/site-packages/spyder_kernels/console/start.py:340 in main
kernel.start()
File ~/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py:724 in start
self.io_loop.start()
File ~/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py:215 in start
self.asyncio_loop.run_forever()
File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/base_events.py:595 in run_forever
self._run_once()
File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/base_events.py:1881 in _run_once
handle._run()
File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/events.py:80 in _run
self._context.run(self._callback, *self._args)
File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:512 in dispatch_queue
await self.process_one()
File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:501 in process_one
await dispatch(*args)
File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:408 in dispatch_shell
await result
File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:731 in execute_request
reply_content = await reply_content
File ~/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:417 in do_execute
res = shell.run_cell(
File ~/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py:540 in run_cell
return super().run_cell(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2945 in run_cell
result = self._run_cell(
File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3000 in _run_cell
return runner(coro)
File ~/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
coro.send(None)
File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3203 in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3382 in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3442 in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
Cell In[7], line 2
d_power = jac(center0, size0, eps0)
File ~/rare_earth_ions/simulation/problem_jacobian.py:100 in power
jax_sim_data = run_adjoint(jax_sim, task_name="adjoint_power", verbose=True)
JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[1,1].
This BatchTracer with object id 140334367546704 was created on line:
/home/pr/rare_earth_ions/simulation/problem_jacobian.py:94 (compute_power)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
Cell In[7], line 2
d_power = jac(center0, size0, eps0)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:951 in jacfun
jac = vmap(pullback)(_std_basis(y))
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/traceback_util.py:166 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:1258 in vmap_f
out_flat = batching.batch(
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/tree_util.py:361 in __call__
return self.fun(*args, **kw)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:2161 in _vjp_pullback_wrapper
ans = fun(*args)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/tree_util.py:361 in __call__
return self.fun(*args, **kw)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:147 in unbound_vjp
arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:254 in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:761 in _custom_lin_transpose
cts_in = bwd(*res, *cts_out)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:769 in <lambda>
bwd_ = lambda *args: bwd(*args)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/web.py:169 in run_bwd
jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/sim_data.py:175 in make_adjoint_simulation
for adj_source in mnt_data_vjp.to_adjoint_sources(fwidth=fwidth):
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/monitor_data.py:83 in to_adjoint_sources
amps, sel_coords = self.amps.nonzero_val_coords
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/components/base.py:51 in cached_property_getter
computed_value = prop(self)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/data_array.py:448 in nonzero_val_coords
values = np.nan_to_num(self.as_ndarray)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/components/base.py:51 in cached_property_getter
computed_value = prop(self)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/data_array.py:131 in as_ndarray
return np.array(self.values)
File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/core.py:611 in __array__
raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[1,1].
This BatchTracer with object id 140334367546704 was created on line:
/home/pr/rare_earth_ions/simulation/problem_jacobian.py:94 (compute_power)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Hi @remypa,
We haven't intended to support this, nor have we tested it. However, we can put it on our roadmap for future versions. It's possible some of my changes in #1551 may make it work, but it's unlikely.
What I don't understand exactly is why you'd need to do this; perhaps you can explain. My understanding is that you have a set of objective functions f_i, each of which depends on a Tidy3D simulation. Typically you'd combine all of these into a single objective function, e.g. by summing over the f_i. In that case you could still use value_and_grad. However, you'd like to store the values of f_i and then do some additional processing with them? Can you explain that a bit more?
Ultimately, if you are able to combine everything you'd like to do into a single objective function, it should still be possible to use value_and_grad.
Note that you can always use has_aux in the value_and_grad call if you simply need to store these f_i values to process outside of the loop. I'd recommend seeing this tutorial, cells [17] and [18], for an example.
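For illustration, here's a minimal sketch of that pattern; compute_objectives is a hypothetical placeholder that returns the vector of f_i values, and summing them is just one way to form the scalar:

import jax
import jax.numpy as jnp

def objective_with_aux(params):
    fs = compute_objectives(params)  # hypothetical helper returning all f_i as a vector
    return jnp.sum(fs), fs  # (scalar objective, auxiliary per-objective values)

(total, fs), grad = jax.value_and_grad(objective_with_aux, has_aux=True)(params)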
Hi,
thanks a lot for your reply.
I am working on a minimax-type problem:
\min_{x \in \mathbb{R}^{n}} \max(f_1(x), f_2(x), \ldots, f_m(x))
which I reformulate as
\min_{x \in \mathbb{R}^{n},\, t \in \mathbb{R}} t \quad \text{s.t. } t \ge f_k(x) \text{ for } k = 1, 2, \ldots, m
as per https://nlopt.readthedocs.io/en/latest/NLopt_Introduction/#equivalent-formulations-of-optimization-problems.
This is why I don't really recombine the f_i into a single objective function.
I see. Yeah, we will have to work on improving compatibility between tidy3d's adjoint plugin and nlopt for these sorts of problems.
In the meantime, I might suggest you use a softmax function such as jax.nn.softmax. You can use it to weight your f_i so as to preferentially penalize the maximum one, keeping your objective function differentiable.
Some pseudo-code below, but double-check the specifics.
import jax
import jax.numpy as jnp

def objective(x):
    fs = jnp.array([f(x, i) for i in range(m)])  # stack the individual objectives
    weights = jax.nn.softmax(fs)  # emphasizes the largest f_i, stays differentiable
    return jnp.sum(weights * fs)
My intuition tells me that this should work reasonably well without needing to transform the problem to constraints.
EDIT: forgot the jnp.sum in the return. Fixed.
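As a usage note, the softmax-weighted sum is still scalar, so value_and_grad applies directly (x0 below is a placeholder design vector); you can also sharpen the weighting toward a hard max by dividing by a small temperature, e.g. jax.nn.softmax(fs / tau):

val, grad = jax.value_and_grad(objective)(x0)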
The direct use of jax.jacrev(vector_valued_fn) requires that all operations in vector_valued_fn() have batching rules defined. If you look at the implementation of jax.jacrev, you'll see that it just vmaps over jax.vjp. Generally, this means that one needs to define JAX primitives with batching rules, which is different from the strategy of defining jax.custom_vjp rules that wrap non-JAX code (Tidy3D's approach). There is no way to avoid leaking JAX types into the wrapped code when higher-order JAX transformations are used on a custom_vjp, which is why you see the error in the OP.
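To make that concrete, here is a rough sketch of the structure of jax.jacrev for a function with a 1-D output (not the actual implementation):

import jax
import jax.numpy as jnp

def jacrev_sketch(f, x):
    # one VJP evaluation yields the pullback; vmapping the pullback over the
    # standard basis of the output space builds the Jacobian row by row,
    # which is exactly where batching rules are required
    y, pullback = jax.vjp(f, x)
    return jax.vmap(pullback)(jnp.eye(y.size))[0]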
You don't need to use jax.jacrev, and it might even be less convenient, since it does not return the vector value (only the Jacobian). You can instead manually manage the construction of the constraint-vector Jacobian using a Python loop (or whatever batching Tidy3D provides). This would allow you to perform the epigraph minimax style of optimization that has been popularized by Meep.
Hey @remypa,
To add to the above: since everything in nlopt needs to happen outside of JAX anyway, it is perfectly fine to construct the constraint vector as @ianwilliamson described. In the simplest case (single wavelength, only differentiating w.r.t. a single argument), that would look something like this:
import numpy as np

def nlopt_epigraph_constraint(result: np.ndarray, x: np.ndarray, gd: np.ndarray) -> None:
    t, v = x[0], x[1:]
    # evaluate all objectives and their gradients; obj_fun_vgs is a list of
    # gradient functions defined elsewhere, each of the form:
    #     obj_fun_vg = jax.value_and_grad(objective)
    obj_vec, grad_vec = [], []
    for obj_fun_vg in obj_fun_vgs:
        obj_val, grad_val = obj_fun_vg(v)
        obj_vec.append(obj_val)
        grad_vec.append(grad_val)
    if gd.size > 0:
        gd[:, 0] = -1  # derivative of (f_k - t) w.r.t. the epigraph variable t
        gd[:, 1:] = np.asarray(grad_vec)  # derivatives w.r.t. the design variables
    result[:] = np.asarray(obj_vec) - t  # feasible when f_k(v) - t <= 0
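For completeness, a hedged sketch of how this constraint might be wired into nlopt; n, m, the tolerances, and the choice of MMA are all placeholders:

import nlopt
import numpy as np

n, m = 10, 3  # placeholder numbers of design parameters and objectives

def epigraph_objective(x, gd):
    # minimize t = x[0]; its gradient is 1 w.r.t. t and 0 w.r.t. the design vars
    if gd.size > 0:
        gd[:] = 0.0
        gd[0] = 1.0
    return x[0]

opt = nlopt.opt(nlopt.LD_MMA, n + 1)  # design variables plus the epigraph variable t
opt.set_min_objective(epigraph_objective)
opt.add_inequality_mconstraint(nlopt_epigraph_constraint, 1e-8 * np.ones(m))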
You can also parallelize the evaluation of the objective functions in that loop, for example with async/await, or possibly via tidy3d's built-in batching, although I haven't checked.
It might be nice to support this out of the box, but that would really just mean moving the bookkeeping (assembling the constraint vector) somewhere inside tidy3d's adjoint module. How that needs to be handled depends on the optimization package; nlopt handles it differently from scipy or IPOPT, for example, so there is no fully general solution.
Hi,
thanks for all your inputs/suggestions.
@ianwilliamson, @yaugenst: that is the way I am doing it at the moment. It does work, but the difficulty is that I don't think I can use Tidy3D's batch infrastructure out of the box, hence my initial post.
@tylerflex : I have started looking into softmax. My initial results are promising.
Cheers.