DESC
DESC copied to clipboard
Error in `map_coordinates` with `jax==0.4.34`
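Calling, for example, `plot_comparison` in our basic equilibrium notebook yields a JAX error on 0.4.34,
related to `_jacobi_jvp`.
There we did something slightly hacky to avoid escaped tracer values: we multiply the tangents of the integer arguments (`n`, `alpha`, `beta`, `dx`) by zero instead of marking them as non-differentiable. It seems JAX no longer tolerates this. The tangents of integer inputs arrive as `float0` arrays, and `float0` arrays reject all arithmetic, including `0 * ndot`, which raises the `TypeError` below.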
{
"name": "TypeError",
"message": "Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.",
"stack": "---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel_launcher.py:18
16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/traitlets/config/application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelapp.py:739, in start()
738 try:
--> 739 self.io_loop.start()
740 except KeyboardInterrupt:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/tornado/platform/asyncio.py:205, in start()
204 def start(self) -> None:
--> 205 self.asyncio_loop.run_forever()
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:618, in run_forever()
617 while True:
--> 618 self._run_once()
619 if self._stopping:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:1951, in _run_once()
1950 else:
-> 1951 handle._run()
1952 handle = None
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/events.py:84, in _run()
83 try:
---> 84 self._context.run(self._callback, *self._args)
85 except (SystemExit, KeyboardInterrupt):
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue()
544 try:
--> 545 await self.process_one()
546 except Exception:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:534, in process_one()
533 return
--> 534 await dispatch(*args)
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
436 if inspect.isawaitable(result):
--> 437 await result
438 except Exception:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:362, in execute_request()
361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:778, in execute_request()
777 if inspect.isawaitable(reply_content):
--> 778 reply_content = await reply_content
780 # Flush output before sending the reply.
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:449, in do_execute()
448 if accepts_params[\"cell_id\"]:
--> 449 res = shell.run_cell(
450 code,
451 store_history=store_history,
452 silent=silent,
453 cell_id=cell_id,
454 )
455 else:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3075, in run_cell()
3074 try:
-> 3075 result = self._run_cell(
3076 raw_cell, store_history, silent, shell_futures, cell_id
3077 )
3078 finally:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell()
3129 try:
-> 3130 result = runner(coro)
3131 except BaseException as e:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner()
127 try:
--> 128 coro.send(None)
129 except StopIteration as exc:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async()
3331 interactivity = \"none\" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3335 interactivity=interactivity, compiler=compiler, result=result)
3337 self.last_execution_succeeded = not has_raised
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes()
3516 asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
3518 return True
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3577, in run_code()
3576 else:
-> 3577 exec(code_obj, self.user_global_ns, self.user_ns)
3578 finally:
3579 # Reset our crash handler in place
Cell In[6], line 3
2 print(\"Number of equilibria in the EquilibriaFamily:\", len(eq_fam))
----> 3 fig, ax = plot_comparison(
4 eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]],
5 labels=[
6 \"Axisymmetric w/o pressure\",
7 \"Axisymmetric w/ pressure\",
8 \"Nonaxisymmetric w/ pressure\",
9 ],
10 )
File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison()
2380 for i, eq in enumerate(eqs):
2381 fig, ax, _plot_data = plot_surfaces(
2382 eq,
2383 rho,
2384 theta,
2385 phi,
2386 ax,
2387 theta_color=color[i % len(color)],
-> 2388 theta_ls=ls[i % len(ls)],
2389 theta_lw=lw[i % len(lw)],
2390 rho_color=color[i % len(color)],
2391 rho_ls=ls[i % len(ls)],
2392 rho_lw=lw[i % len(lw)],
2393 lcfs_color=color[i % len(color)],
2394 lcfs_ls=ls[i % len(ls)],
2395 lcfs_lw=lw[i % len(lw)],
2396 axis_color=color[i % len(color)],
2397 axis_alpha=0,
2398 axis_marker=\"o\",
2399 axis_size=0,
2400 label=labels[i % len(labels)],
2401 title_fontsize=title_fontsize,
2402 xlabel_fontsize=xlabel_fontsize,
2403 ylabel_fontsize=ylabel_fontsize,
2404 return_data=True,
2405 )
2406 for key in _plot_data.keys():
File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces()
1676 tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta
1677 v_grid = Grid(
1678 map_coordinates(
1679 eq,
1680 t_grid.nodes,
1681 [\"rho\", \"theta_PEST\", \"phi\"],
1682 [\"rho\", \"theta\", \"zeta\"],
1683 period=(np.inf, 2 * np.pi, 2 * np.pi),
1684 guess=t_grid.nodes,
-> 1685 ),
1686 sort=False,
1687 )
1688 rows = np.floor(np.sqrt(nphi)).astype(int)
File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates()
215 # See description here
216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
217 # except we make sure properly handle periodic coordinates.
--> 218 yk, (res, niter) = vecroot(yk, coords)
220 out = compute(yk, outbasis)
File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>()
199 yk = fixup(yk)
201 vecroot = jit(
202 vmap(
--> 203 lambda x0, *p: root(
204 residual,
205 x0,
206 jac=jac,
207 args=p,
208 fixup=fixup,
209 tol=tol,
210 maxiter=maxiter,
211 **kwargs,
212 )
213 )
214 )
215 # See description here
216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
217 # except we make sure properly handle periodic coordinates.
File ~/Research/DESC/desc/backend.py:398, in root()
396 return _lstsq(A, jnp.atleast_1d(y))
--> 398 x, (res, niter) = jax.lax.custom_root(
399 res, x0, solve, tangent_solve, has_aux=True
400 )
401 return x, (safenorm(res), niter)
File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>()
342 jac2 = lambda x: jnp.atleast_2d(jac(x, *args))
--> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()
346 # want to use least squares for rank-defficient systems, but
347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed
348 # so we use the normal equations with regularized cholesky
File ~/Research/DESC/desc/equilibrium/coords.py:174, in residual()
172 @jit
173 def residual(y, coords):
--> 174 xk = compute(y, inbasis)
175 return _fixup_residual(xk - coords, period)
File ~/Research/DESC/desc/equilibrium/coords.py:167, in compute()
166 data[\"iota_rr\"] = profiles[\"iota\"].compute(grid, dr=2, params=params[\"i_l\"])
--> 167 transforms = get_transforms(basis, eq, grid, jitable=True)
168 data = compute_fun(eq, basis, params, transforms, profiles, data)
File ~/Research/DESC/desc/backend.py:112, in wrapper()
111 with jax.default_device(jax.devices(\"cpu\")[0]):
--> 112 return func(*args, **kwargs)
File ~/Research/DESC/desc/compute/utils.py:631, in get_transforms()
630 if hasattr(t, \"build\"):
--> 631 t.build()
633 return transforms
File ~/Research/DESC/desc/transform.py:385, in build()
384 for d in self.derivatives:
--> 385 self.matrices[\"direct1\"][d[0]][d[1]][d[2]] = self.basis.evaluate(
386 self.grid.nodes, d, unique=False
387 )
389 if self.method in [\"fft\", \"direct2\"]:
File ~/Research/DESC/desc/basis.py:1134, in evaluate()
1132 n = n[nidx]
-> 1134 radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
1135 poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1])
File ~/Research/DESC/desc/basis.py:1533, in zernike_radial()
1532 if dr == 0:
-> 1533 out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)
1534 elif dr == 1:
JaxStackTraceBeforeTransformation: TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[6], line 3
1 eq_fam = desc.io.load(\"input.HELIOTRON_output.h5\")
2 print(\"Number of equilibria in the EquilibriaFamily:\", len(eq_fam))
----> 3 fig, ax = plot_comparison(
4 eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]],
5 labels=[
6 \"Axisymmetric w/o pressure\",
7 \"Axisymmetric w/ pressure\",
8 \"Nonaxisymmetric w/ pressure\",
9 ],
10 )
File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison(eqs, rho, theta, phi, ax, cmap, color, lw, ls, labels, return_data, **kwargs)
2386 plot_data[string] = []
2387 for i, eq in enumerate(eqs):
-> 2388 fig, ax, _plot_data = plot_surfaces(
2389 eq,
2390 rho,
2391 theta,
2392 phi,
2393 ax,
2394 theta_color=color[i % len(color)],
2395 theta_ls=ls[i % len(ls)],
2396 theta_lw=lw[i % len(lw)],
2397 rho_color=color[i % len(color)],
2398 rho_ls=ls[i % len(ls)],
2399 rho_lw=lw[i % len(lw)],
2400 lcfs_color=color[i % len(color)],
2401 lcfs_ls=ls[i % len(ls)],
2402 lcfs_lw=lw[i % len(lw)],
2403 axis_color=color[i % len(color)],
2404 axis_alpha=0,
2405 axis_marker=\"o\",
2406 axis_size=0,
2407 label=labels[i % len(labels)],
2408 title_fontsize=title_fontsize,
2409 xlabel_fontsize=xlabel_fontsize,
2410 ylabel_fontsize=ylabel_fontsize,
2411 return_data=True,
2412 )
2413 for key in _plot_data.keys():
2414 plot_data[key].append(_plot_data[key])
File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces(eq, rho, theta, phi, ax, return_data, **kwargs)
1682 t_grid = _get_grid(**grid_kwargs)
1683 tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta
1684 v_grid = Grid(
-> 1685 map_coordinates(
1686 eq,
1687 t_grid.nodes,
1688 [\"rho\", \"theta_PEST\", \"phi\"],
1689 [\"rho\", \"theta\", \"zeta\"],
1690 period=(np.inf, 2 * np.pi, 2 * np.pi),
1691 guess=t_grid.nodes,
1692 ),
1693 sort=False,
1694 )
1695 rows = np.floor(np.sqrt(nphi)).astype(int)
1696 cols = np.ceil(nphi / rows).astype(int)
File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates(eq, coords, inbasis, outbasis, guess, params, period, tol, maxiter, full_output, **kwargs)
201 vecroot = jit(
202 vmap(
203 lambda x0, *p: root(
(...)
213 )
214 )
215 # See description here
216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
217 # except we make sure properly handle periodic coordinates.
--> 218 yk, (res, niter) = vecroot(yk, coords)
220 out = compute(yk, outbasis)
221 if full_output:
[... skipping hidden 14 frame]
File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>(x0, *p)
197 yk = _initial_guess_nn_search(coords, inbasis, eq, period, compute)
199 yk = fixup(yk)
201 vecroot = jit(
202 vmap(
--> 203 lambda x0, *p: root(
204 residual,
205 x0,
206 jac=jac,
207 args=p,
208 fixup=fixup,
209 tol=tol,
210 maxiter=maxiter,
211 **kwargs,
212 )
213 )
214 )
215 # See description here
216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
217 # except we make sure properly handle periodic coordinates.
218 yk, (res, niter) = vecroot(yk, coords)
File ~/Research/DESC/desc/backend.py:398, in root(fun, x0, jac, args, tol, maxiter, maxiter_ls, alpha, fixup)
395 A = jnp.atleast_2d(jax.jacfwd(g)(y))
396 return _lstsq(A, jnp.atleast_1d(y))
--> 398 x, (res, niter) = jax.lax.custom_root(
399 res, x0, solve, tangent_solve, has_aux=True
400 )
401 return x, (safenorm(res), niter)
[... skipping hidden 14 frame]
File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>(x)
341 else:
342 jac2 = lambda x: jnp.atleast_2d(jac(x, *args))
--> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()
346 # want to use least squares for rank-defficient systems, but
347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed
348 # so we use the normal equations with regularized cholesky
349 def _lstsq(a, b):
[... skipping hidden 47 frame]
File ~/Research/DESC/desc/basis.py:1816, in _jacobi_jvp(x, xdot)
1811 df = _jacobi(n, alpha, beta, x, dx + 1)
1812 # in theory n, alpha, beta, dx aren't differentiable (they're integers)
1813 # but marking them as non-diff argnums seems to cause escaped tracer values.
1814 # probably a more elegant fix, but just setting those derivatives to zero seems
1815 # to work fine.
-> 1816 return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:1036, in _forward_operator_to_aval.<locals>.op(self, *args)
1035 def op(self, *args):
-> 1036 return getattr(self.aval, f\"_{name}\")(self, *args)
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:573, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
571 args = (other, self) if swap else (self, other)
572 if isinstance(other, _accepted_binop_types):
--> 573 return binary_op(*args)
574 # Note: don't use isinstance here, because we don't want to raise for
575 # subclasses, e.g. NamedTuple objects that may override operators.
576 if type(other) in _rejected_binop_types:
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py:177, in ufunc.__call__(self, out, where, *args)
175 raise NotImplementedError(f\"where argument of {self}\")
176 call = self.__static_props['call'] or self._call_vectorized
--> 177 return call(*args)
[... skipping hidden 11 frame]
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py:1142, in _multiply(x, y)
1115 @partial(jit, inline=True)
1116 def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
1117 \"\"\"Multiply two arrays element-wise.
1118
1119 JAX implementation of :obj:`numpy.multiply`. This is a universal function,
(...)
1140 Array([ 0, 10, 20, 30], dtype=int32)
1141 \"\"\"
-> 1142 x, y = promote_args(\"multiply\", x, y)
1143 return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:355, in promote_args(fun_name, *args)
353 \"\"\"Convenience function to apply Numpy argument shape and dtype promotion.\"\"\"
354 check_arraylike(fun_name, *args)
--> 355 _check_no_float0s(fun_name, *args)
356 check_for_prngkeys(fun_name, *args)
357 return promote_shapes(fun_name, *promote_dtypes(*args))
File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:326, in check_no_float0s(fun_name, *args)
324 \"\"\"Check if none of the args have dtype float0.\"\"\"
325 if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
--> 326 raise TypeError(
327 f\"Called {fun_name} with a float0 array. \"
328 \"float0s do not support any operations by design because they \"
329 \"are not compatible with non-trivial vector spaces. No implicit dtype \"
330 \"conversion is done. You can use np.zeros_like(arr, dtype=np.float) \"
331 \"to cast a float0 array to a regular zeros array. \
\"
332 \"If you didn't expect to get a float0 you might have accidentally \"
333 \"taken a gradient with respect to an integer argument.\")
TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument."
}