DESC icon indicating copy to clipboard operation
DESC copied to clipboard

Error in `map_coordinates` with `jax==0.4.34`

Open dpanici opened this issue 4 months ago • 2 comments

Calling, for example, `plot_comparison` in our basic equilibrium notebook yields a JAX error on 0.4.34.

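This is related to `_jacobi_jvp`, where we did something slightly hacky to avoid escaped tracer values. The JVP rule adds `0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot` for the integer arguments. Those tangents are apparently `float0` arrays, and multiplying a `float0` array raises the `TypeError` below. So it seems JAX no longer tolerates that workaround.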

{
	"name": "TypeError",
	"message": "Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.",
	"stack": "---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/traitlets/config/application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelapp.py:739, in start()
    738 try:
--> 739     self.io_loop.start()
    740 except KeyboardInterrupt:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/tornado/platform/asyncio.py:205, in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:618, in run_forever()
    617 while True:
--> 618     self._run_once()
    619     if self._stopping:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:1951, in _run_once()
   1950     else:
-> 1951         handle._run()
   1952 handle = None

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/events.py:84, in _run()
     83 try:
---> 84     self._context.run(self._callback, *self._args)
     85 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue()
    544 try:
--> 545     await self.process_one()
    546 except Exception:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:534, in process_one()
    533         return
--> 534 await dispatch(*args)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
    436     if inspect.isawaitable(result):
--> 437         await result
    438 except Exception:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:362, in execute_request()
    361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:778, in execute_request()
    777 if inspect.isawaitable(reply_content):
--> 778     reply_content = await reply_content
    780 # Flush output before sending the reply.

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:449, in do_execute()
    448 if accepts_params[\"cell_id\"]:
--> 449     res = shell.run_cell(
    450         code,
    451         store_history=store_history,
    452         silent=silent,
    453         cell_id=cell_id,
    454     )
    455 else:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3075, in run_cell()
   3074 try:
-> 3075     result = self._run_cell(
   3076         raw_cell, store_history, silent, shell_futures, cell_id
   3077     )
   3078 finally:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell()
   3129 try:
-> 3130     result = runner(coro)
   3131 except BaseException as e:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner()
    127 try:
--> 128     coro.send(None)
    129 except StopIteration as exc:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async()
   3331 interactivity = \"none\" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3335        interactivity=interactivity, compiler=compiler, result=result)
   3337 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes()
   3516     asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
   3518     return True

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3577, in run_code()
   3576     else:
-> 3577         exec(code_obj, self.user_global_ns, self.user_ns)
   3578 finally:
   3579     # Reset our crash handler in place

Cell In[6], line 3
      2 print(\"Number of equilibria in the EquilibriaFamily:\", len(eq_fam))
----> 3 fig, ax = plot_comparison(
      4     eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]],
      5     labels=[
      6         \"Axisymmetric w/o pressure\",
      7         \"Axisymmetric w/ pressure\",
      8         \"Nonaxisymmetric w/ pressure\",
      9     ],
     10 )

File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison()
   2380 for i, eq in enumerate(eqs):
   2381     fig, ax, _plot_data = plot_surfaces(
   2382         eq,
   2383         rho,
   2384         theta,
   2385         phi,
   2386         ax,
   2387         theta_color=color[i % len(color)],
-> 2388         theta_ls=ls[i % len(ls)],
   2389         theta_lw=lw[i % len(lw)],
   2390         rho_color=color[i % len(color)],
   2391         rho_ls=ls[i % len(ls)],
   2392         rho_lw=lw[i % len(lw)],
   2393         lcfs_color=color[i % len(color)],
   2394         lcfs_ls=ls[i % len(ls)],
   2395         lcfs_lw=lw[i % len(lw)],
   2396         axis_color=color[i % len(color)],
   2397         axis_alpha=0,
   2398         axis_marker=\"o\",
   2399         axis_size=0,
   2400         label=labels[i % len(labels)],
   2401         title_fontsize=title_fontsize,
   2402         xlabel_fontsize=xlabel_fontsize,
   2403         ylabel_fontsize=ylabel_fontsize,
   2404         return_data=True,
   2405     )
   2406     for key in _plot_data.keys():

File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces()
   1676     tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta
   1677     v_grid = Grid(
   1678         map_coordinates(
   1679             eq,
   1680             t_grid.nodes,
   1681             [\"rho\", \"theta_PEST\", \"phi\"],
   1682             [\"rho\", \"theta\", \"zeta\"],
   1683             period=(np.inf, 2 * np.pi, 2 * np.pi),
   1684             guess=t_grid.nodes,
-> 1685         ),
   1686         sort=False,
   1687     )
   1688 rows = np.floor(np.sqrt(nphi)).astype(int)

File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates()
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.
--> 218 yk, (res, niter) = vecroot(yk, coords)
    220 out = compute(yk, outbasis)

File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>()
    199 yk = fixup(yk)
    201 vecroot = jit(
    202     vmap(
--> 203         lambda x0, *p: root(
    204             residual,
    205             x0,
    206             jac=jac,
    207             args=p,
    208             fixup=fixup,
    209             tol=tol,
    210             maxiter=maxiter,
    211             **kwargs,
    212         )
    213     )
    214 )
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.

File ~/Research/DESC/desc/backend.py:398, in root()
    396     return _lstsq(A, jnp.atleast_1d(y))
--> 398 x, (res, niter) = jax.lax.custom_root(
    399     res, x0, solve, tangent_solve, has_aux=True
    400 )
    401 return x, (safenorm(res), niter)

File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>()
    342     jac2 = lambda x: jnp.atleast_2d(jac(x, *args))
--> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()
    346 # want to use least squares for rank-defficient systems, but
    347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed
    348 # so we use the normal equations with regularized cholesky

File ~/Research/DESC/desc/equilibrium/coords.py:174, in residual()
    172 @jit
    173 def residual(y, coords):
--> 174     xk = compute(y, inbasis)
    175     return _fixup_residual(xk - coords, period)

File ~/Research/DESC/desc/equilibrium/coords.py:167, in compute()
    166     data[\"iota_rr\"] = profiles[\"iota\"].compute(grid, dr=2, params=params[\"i_l\"])
--> 167 transforms = get_transforms(basis, eq, grid, jitable=True)
    168 data = compute_fun(eq, basis, params, transforms, profiles, data)

File ~/Research/DESC/desc/backend.py:112, in wrapper()
    111 with jax.default_device(jax.devices(\"cpu\")[0]):
--> 112     return func(*args, **kwargs)

File ~/Research/DESC/desc/compute/utils.py:631, in get_transforms()
    630     if hasattr(t, \"build\"):
--> 631         t.build()
    633 return transforms

File ~/Research/DESC/desc/transform.py:385, in build()
    384     for d in self.derivatives:
--> 385         self.matrices[\"direct1\"][d[0]][d[1]][d[2]] = self.basis.evaluate(
    386             self.grid.nodes, d, unique=False
    387         )
    389 if self.method in [\"fft\", \"direct2\"]:

File ~/Research/DESC/desc/basis.py:1134, in evaluate()
   1132     n = n[nidx]
-> 1134 radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
   1135 poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1])

File ~/Research/DESC/desc/basis.py:1533, in zernike_radial()
   1532 if dr == 0:
-> 1533     out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)
   1534 elif dr == 1:

JaxStackTraceBeforeTransformation: TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 3
      1 eq_fam = desc.io.load(\"input.HELIOTRON_output.h5\")
      2 print(\"Number of equilibria in the EquilibriaFamily:\", len(eq_fam))
----> 3 fig, ax = plot_comparison(
      4     eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]],
      5     labels=[
      6         \"Axisymmetric w/o pressure\",
      7         \"Axisymmetric w/ pressure\",
      8         \"Nonaxisymmetric w/ pressure\",
      9     ],
     10 )

File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison(eqs, rho, theta, phi, ax, cmap, color, lw, ls, labels, return_data, **kwargs)
   2386     plot_data[string] = []
   2387 for i, eq in enumerate(eqs):
-> 2388     fig, ax, _plot_data = plot_surfaces(
   2389         eq,
   2390         rho,
   2391         theta,
   2392         phi,
   2393         ax,
   2394         theta_color=color[i % len(color)],
   2395         theta_ls=ls[i % len(ls)],
   2396         theta_lw=lw[i % len(lw)],
   2397         rho_color=color[i % len(color)],
   2398         rho_ls=ls[i % len(ls)],
   2399         rho_lw=lw[i % len(lw)],
   2400         lcfs_color=color[i % len(color)],
   2401         lcfs_ls=ls[i % len(ls)],
   2402         lcfs_lw=lw[i % len(lw)],
   2403         axis_color=color[i % len(color)],
   2404         axis_alpha=0,
   2405         axis_marker=\"o\",
   2406         axis_size=0,
   2407         label=labels[i % len(labels)],
   2408         title_fontsize=title_fontsize,
   2409         xlabel_fontsize=xlabel_fontsize,
   2410         ylabel_fontsize=ylabel_fontsize,
   2411         return_data=True,
   2412     )
   2413     for key in _plot_data.keys():
   2414         plot_data[key].append(_plot_data[key])

File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces(eq, rho, theta, phi, ax, return_data, **kwargs)
   1682     t_grid = _get_grid(**grid_kwargs)
   1683     tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta
   1684     v_grid = Grid(
-> 1685         map_coordinates(
   1686             eq,
   1687             t_grid.nodes,
   1688             [\"rho\", \"theta_PEST\", \"phi\"],
   1689             [\"rho\", \"theta\", \"zeta\"],
   1690             period=(np.inf, 2 * np.pi, 2 * np.pi),
   1691             guess=t_grid.nodes,
   1692         ),
   1693         sort=False,
   1694     )
   1695 rows = np.floor(np.sqrt(nphi)).astype(int)
   1696 cols = np.ceil(nphi / rows).astype(int)

File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates(eq, coords, inbasis, outbasis, guess, params, period, tol, maxiter, full_output, **kwargs)
    201 vecroot = jit(
    202     vmap(
    203         lambda x0, *p: root(
   (...)
    213     )
    214 )
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.
--> 218 yk, (res, niter) = vecroot(yk, coords)
    220 out = compute(yk, outbasis)
    221 if full_output:

    [... skipping hidden 14 frame]

File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>(x0, *p)
    197     yk = _initial_guess_nn_search(coords, inbasis, eq, period, compute)
    199 yk = fixup(yk)
    201 vecroot = jit(
    202     vmap(
--> 203         lambda x0, *p: root(
    204             residual,
    205             x0,
    206             jac=jac,
    207             args=p,
    208             fixup=fixup,
    209             tol=tol,
    210             maxiter=maxiter,
    211             **kwargs,
    212         )
    213     )
    214 )
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.
    218 yk, (res, niter) = vecroot(yk, coords)

File ~/Research/DESC/desc/backend.py:398, in root(fun, x0, jac, args, tol, maxiter, maxiter_ls, alpha, fixup)
    395     A = jnp.atleast_2d(jax.jacfwd(g)(y))
    396     return _lstsq(A, jnp.atleast_1d(y))
--> 398 x, (res, niter) = jax.lax.custom_root(
    399     res, x0, solve, tangent_solve, has_aux=True
    400 )
    401 return x, (safenorm(res), niter)

    [... skipping hidden 14 frame]

File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>(x)
    341 else:
    342     jac2 = lambda x: jnp.atleast_2d(jac(x, *args))
--> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()
    346 # want to use least squares for rank-defficient systems, but
    347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed
    348 # so we use the normal equations with regularized cholesky
    349 def _lstsq(a, b):

    [... skipping hidden 47 frame]

File ~/Research/DESC/desc/basis.py:1816, in _jacobi_jvp(x, xdot)
   1811 df = _jacobi(n, alpha, beta, x, dx + 1)
   1812 # in theory n, alpha, beta, dx aren't differentiable (they're integers)
   1813 # but marking them as non-diff argnums seems to cause escaped tracer values.
   1814 # probably a more elegant fix, but just setting those derivatives to zero seems
   1815 # to work fine.
-> 1816 return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:1036, in _forward_operator_to_aval.<locals>.op(self, *args)
   1035 def op(self, *args):
-> 1036   return getattr(self.aval, f\"_{name}\")(self, *args)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:573, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    571 args = (other, self) if swap else (self, other)
    572 if isinstance(other, _accepted_binop_types):
--> 573   return binary_op(*args)
    574 # Note: don't use isinstance here, because we don't want to raise for
    575 # subclasses, e.g. NamedTuple objects that may override operators.
    576 if type(other) in _rejected_binop_types:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py:177, in ufunc.__call__(self, out, where, *args)
    175   raise NotImplementedError(f\"where argument of {self}\")
    176 call = self.__static_props['call'] or self._call_vectorized
--> 177 return call(*args)

    [... skipping hidden 11 frame]

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py:1142, in _multiply(x, y)
   1115 @partial(jit, inline=True)
   1116 def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
   1117   \"\"\"Multiply two arrays element-wise.
   1118 
   1119   JAX implementation of :obj:`numpy.multiply`. This is a universal function,
   (...)
   1140     Array([ 0, 10, 20, 30], dtype=int32)
   1141   \"\"\"
-> 1142   x, y = promote_args(\"multiply\", x, y)
   1143   return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:355, in promote_args(fun_name, *args)
    353 \"\"\"Convenience function to apply Numpy argument shape and dtype promotion.\"\"\"
    354 check_arraylike(fun_name, *args)
--> 355 _check_no_float0s(fun_name, *args)
    356 check_for_prngkeys(fun_name, *args)
    357 return promote_shapes(fun_name, *promote_dtypes(*args))

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:326, in check_no_float0s(fun_name, *args)
    324 \"\"\"Check if none of the args have dtype float0.\"\"\"
    325 if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
--> 326   raise TypeError(
    327       f\"Called {fun_name} with a float0 array. \"
    328       \"float0s do not support any operations by design because they \"
    329       \"are not compatible with non-trivial vector spaces. No implicit dtype \"
    330       \"conversion is done. You can use np.zeros_like(arr, dtype=np.float) \"
    331       \"to cast a float0 array to a regular zeros array. \
\"
    332       \"If you didn't expect to get a float0 you might have accidentally \"
    333       \"taken a gradient with respect to an integer argument.\")

TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument."
}

dpanici avatar Oct 04 '24 16:10 dpanici