
Vector-valued objective functions

Open · remypa opened this issue 11 months ago · 7 comments

I am using tidy3D to perform multi-objective inverse design. To speed things up, I'd like to run batches with web.run_async and return the different objectives as a vector value (I need the individual values to define multiple constraints in my optimisation problem).

To do that, I can't use jax.value_and_grad (which is limited to scalar functions). Instead, I need to use jax.jacrev or jax.jacfwd.

But when I do that, I run into a ConcretizationTypeError if I'm using a FieldMonitor, or a TracerArrayConversionError if I'm using a ModeMonitor.

A simple way to reproduce the problem is to extend the tutorial at https://www.flexcompute.com/tidy3d/examples/notebooks/AdjointPlugin1Intro/ with:

import jax

jac = jax.jacrev(power, argnums=(0, 1, 2))
d_power = jac(center0, size0, eps0)

This is not a vector-valued objective function, but the problem is the same:

Traceback (most recent call last):

  File ~/micromamba/envs/tidy3d/lib/python3.10/runpy.py:196 in _run_module_as_main
    return _run_code(code, main_globals, None,

  File ~/micromamba/envs/tidy3d/lib/python3.10/runpy.py:86 in _run_code
    exec(code, run_globals)

  File ~/.local/lib/python3.10/site-packages/spyder_kernels/console/__main__.py:24
    start.main()

  File ~/.local/lib/python3.10/site-packages/spyder_kernels/console/start.py:340 in main
    kernel.start()

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py:724 in start
    self.io_loop.start()

  File ~/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py:215 in start
    self.asyncio_loop.run_forever()

  File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/base_events.py:595 in run_forever
    self._run_once()

  File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/base_events.py:1881 in _run_once
    handle._run()

  File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/events.py:80 in _run
    self._context.run(self._callback, *self._args)

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:512 in dispatch_queue
    await self.process_one()

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:501 in process_one
    await dispatch(*args)

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:408 in dispatch_shell
    await result

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:731 in execute_request
    reply_content = await reply_content

  File ~/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:417 in do_execute
    res = shell.run_cell(

  File ~/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py:540 in run_cell
    return super().run_cell(*args, **kwargs)

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2945 in run_cell
    result = self._run_cell(

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3000 in _run_cell
    return runner(coro)

  File ~/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
    coro.send(None)

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3203 in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3382 in run_ast_nodes
    if await self.run_code(code, result, async_=asy):

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3442 in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)

  Cell In[7], line 2
    d_power = jac(center0, size0, eps0)

  File ~/rare_earth_ions/simulation/problem_jacobian.py:100 in power
    jax_sim_data = run_adjoint(jax_sim, task_name="adjoint_power", verbose=True)

JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[1,1].
This BatchTracer with object id 140334367546704 was created on line:
  /home/pr/rare_earth_ions/simulation/problem_jacobian.py:94 (compute_power)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------


The above exception was the direct cause of the following exception:

Traceback (most recent call last):

  Cell In[7], line 2
    d_power = jac(center0, size0, eps0)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:951 in jacfun
    jac = vmap(pullback)(_std_basis(y))

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:1258 in vmap_f
    out_flat = batching.batch(

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/tree_util.py:361 in __call__
    return self.fun(*args, **kw)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:2161 in _vjp_pullback_wrapper
    ans = fun(*args)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/tree_util.py:361 in __call__
    return self.fun(*args, **kw)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:147 in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:254 in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:761 in _custom_lin_transpose
    cts_in = bwd(*res, *cts_out)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:769 in <lambda>
    bwd_ = lambda *args: bwd(*args)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/web.py:169 in run_bwd
    jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/sim_data.py:175 in make_adjoint_simulation
    for adj_source in mnt_data_vjp.to_adjoint_sources(fwidth=fwidth):

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/monitor_data.py:83 in to_adjoint_sources
    amps, sel_coords = self.amps.nonzero_val_coords

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/components/base.py:51 in cached_property_getter
    computed_value = prop(self)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/data_array.py:448 in nonzero_val_coords
    values = np.nan_to_num(self.as_ndarray)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/components/base.py:51 in cached_property_getter
    computed_value = prop(self)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/data_array.py:131 in as_ndarray
    return np.array(self.values)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/core.py:611 in __array__
    raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[1,1].
This BatchTracer with object id 140334367546704 was created on line:
  /home/pr/rare_earth_ions/simulation/problem_jacobian.py:94 (compute_power)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

remypa · Mar 17 '24 16:03

Hi @remypa ,

This isn't something we've intended to support, nor have we tested it before. However, we can put it on our roadmap for future versions. It's possible some of my changes in #1551 may make it work, but it's unlikely.

What I don't understand exactly is why you'd need to do this; perhaps you can explain. My understanding is that you have a set of objective functions f_i, each of which depends on a tidy3d simulation.

Typically you'd combine all of these into a single objective function, e.g. by summing over f_i. In that case you could still use value_and_grad. However, you'd like to store the values of f_i and then do some additional processing with them? Can you explain that a bit more?

Ultimately if you are able to combine everything you'd like to do in a single objective function, it should still be possible to use value_and_grad.

Note that you can always use has_aux in the value_and_grad call if you simply need to store these f_i values to process outside of the loop. I'd recommend looking at this tutorial, cells [17] and [18], for an example.
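For example, a minimal sketch of that pattern (here fs is a placeholder for your list of scalar objectives f_i, and x0 for your design parameters):

import jax
import jax.numpy as jnp

def combined(x):
    # fs is a placeholder list of scalar objective functions f_i
    f_vals = jnp.array([f_i(x) for f_i in fs])
    # scalar objective for the gradient, individual values as auxiliary output
    return jnp.sum(f_vals), f_vals

(value, f_vals), grads = jax.value_and_grad(combined, has_aux=True)(x0)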

tylerflex · Mar 18 '24 13:03

Hi,

thanks a lot for your reply.

I am working on a minimax-type problem:

\min_{x \in \mathbb{R}^{n}} \max(f_1(x), f_2(x), \dots, f_m(x))

which I reformulate as

\min_{x \in \mathbb{R}^{n},\, t \in \mathbb{R}} t \quad \text{s.t. } t \ge f_k(x) \text{ for } k = 1, 2, \dots, m

as per https://nlopt.readthedocs.io/en/latest/NLopt_Introduction/#equivalent-formulations-of-optimization-problems.

That is why I don't really recombine the f_i into a single objective function.

remypa · Mar 18 '24 14:03

I see. Yeah, we will have to work on improving the compatibility between tidy3d adjoint and nlopt for these sorts of problems.

In the meantime, I might suggest you use a softmax function such as jax.nn.softmax. You can use this to weight your f_i to preferentially penalize the maximum one, such that your objective function is still differentiable.

Some pseudo-code is below, but double-check the specifics.

import jax
import jax.numpy as jnp

def objective(x):
    # weight each sub-objective by its softmax so the largest one dominates
    fs = jnp.array([f(x, i) for i in range(m)])
    weights = jax.nn.softmax(fs)
    return jnp.sum(weights * fs)

My intuition tells me that this should work reasonably well without needing to transform the problem to constraints.
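If the plain softmax weighting turns out to be too soft, one common knob (untested with tidy3d on my end) is a temperature factor beta that sharpens the weighting toward a hard max:

import jax
import jax.numpy as jnp

def objective(x, beta=10.0):
    # f(x, i) and m are the same placeholders as in the sketch above
    fs = jnp.array([f(x, i) for i in range(m)])
    # larger beta concentrates the weights on the largest f_i while
    # keeping the objective differentiable
    weights = jax.nn.softmax(beta * fs)
    return jnp.sum(weights * fs)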

tylerflex · Mar 18 '24 15:03

EDIT: forgot the jnp.sum in the return. Fixed.

tylerflex · Mar 18 '24 15:03

The direct use of jax.jacrev(vector_valued_fn) requires that all operations in vector_valued_fn() have batching rules defined. If you look at the implementation of jax.jacrev, you'll see that it just vmaps over jax.vjp. Generally, this means that one needs to define JAX primitives with batching rules, which is different from the strategy of defining jax.custom_vjp rules that wrap non-JAX code (Tidy3D's approach). There is no way to avoid leaking JAX types into the wrapped code when higher-order JAX transformations are used on a custom_vjp, which is why you see the error in the OP.
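Schematically, that looks something like the sketch below (a simplified illustration, not JAX's actual source; the traceback above even shows the real line, jac = vmap(pullback)(_std_basis(y))):

import jax
import jax.numpy as jnp

def jacrev_sketch(f, x):
    # reverse mode: one VJP per output component
    y, pullback = jax.vjp(f, x)
    # assumes f returns a 1D array; the vmap over the standard basis of the
    # output space is what requires batching rules on every primitive in f
    basis = jnp.eye(y.size, dtype=y.dtype)
    return jax.vmap(pullback)(basis)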

You don't need to use jax.jacrev, and it might even be less convenient since it does not return the vector value (only the Jacobian). You can instead manually manage the construction of the constraint vector Jacobian using a Python loop (or whatever batching Tidy3D provides). This would allow you to perform the epigraph minimax style of optimization that has been popularized by Meep.

ianwilliamson · Mar 18 '24 17:03

Hey @remypa,

To add to the above: since everything in nlopt needs to happen outside of JAX anyway, it is perfectly fine to construct the constraint vector as @ianwilliamson described. In the simplest case (single wavelength, only differentiating w.r.t. a single argument), that would look something like this:

import numpy as np

def nlopt_epigraph_constraint(result: np.ndarray, x: np.ndarray, gd: np.ndarray) -> None:
    # x = [t, v]: t is the epigraph variable, v are the design parameters
    t, v = x[0], x[1:]

    # evaluate all objectives and get their gradients, assuming obj_fun_vgs is a list
    # of gradient functions defined somewhere else of the form:
    # d_obj = jax.value_and_grad(objective)
    obj_vec, grad_vec = [], []
    for obj_fun_vg in obj_fun_vgs:
        obj_val, grad_val = obj_fun_vg(v)
        obj_vec.append(obj_val)
        grad_vec.append(grad_val)

    # nlopt passes an empty gradient array when no gradients are requested
    if gd.size > 0:
        gd[:, 0] = -1  # d(f_k - t)/dt
        gd[:, 1:] = np.asarray(grad_vec)  # d(f_k - t)/dv

    # the epigraph constraints f_k(v) - t <= 0, one entry per objective
    result[:] = np.asarray(obj_vec) - t

You can also parallelize the evaluation of the objective functions in that loop with something like async/await, or it may be possible via tidy3d's built-in batching, although I don't know. It might be nice to support this out of the box, but it would really just mean moving the bookkeeping (assembling the constraint vector) to somewhere inside tidy3d's adjoint module. How that needs to be handled exactly depends on the optimization package: how nlopt handles this might differ from scipy or IPOPT, for example, so there is no general solution there.
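For instance, a rough thread-based sketch (reusing the hypothetical obj_fun_vgs list from the snippet above; whether threads play nicely with tidy3d's web API is something to verify):

import numpy as np
from concurrent.futures import ThreadPoolExecutor

def evaluate_all(v: np.ndarray):
    # dispatch every value-and-gradient evaluation concurrently
    with ThreadPoolExecutor() as pool:
        results = list(pool.map(lambda fn: fn(v), obj_fun_vgs))
    obj_vec, grad_vec = zip(*results)
    return np.asarray(obj_vec), np.asarray(grad_vec)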

yaugenst · Mar 19 '24 09:03

Hi,

thanks for all your inputs/suggestions.

@ianwilliamson, @yaugenst: that is the way I am doing it at the moment. It does work, but the difficulty is that I don't think I can use Tidy3D's batch infrastructure out of the box, hence my initial post.

@tylerflex: I have started looking into softmax. My initial results are promising.

Cheers.

remypa · Mar 19 '24 09:03