awkward
Custom behaviors plus jax leading to lookup in wrong spot
Version of Awkward Array
ce63bf2
Description and code to reproduce
This is a partner issue to https://github.com/CoffeaTeam/coffea/issues/874, as the problem is perhaps more on the awkward side than the coffea side. I am trying to combine custom behaviors (defined by coffea) with the jax backend of awkward. The reproducer below results in:
AttributeError: module 'jax.numpy' has no attribute '_mass2_kernel'
Reproducer:
import awkward as ak
from coffea.nanoevents.methods import candidate
import numpy as np
import uproot
ak.jax.register_and_check()
ak.behavior.update(candidate.behavior)
ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
"raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
with uproot.open(ttbar_file) as f:
    arr = f["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
                              "Electron_mass", "Electron_charge"])
px = arr.Electron_pt * np.cos(arr.Electron_phi)
py = arr.Electron_pt * np.sin(arr.Electron_phi)
pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)
evtfilter = ak.num(arr["Electron_pt"]) >= 2
els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
              "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
els = ak.to_backend(els, "jax")
(els[:, 0] + els[:, 1]).mass
Using the "Momentum4D" behavior from vector (after vector.register_awkward()) works. Skipping the backend conversion to jax also makes this work.
Full trace
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[1], line 32
28 els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
29 "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
30 els = ak.to_backend(els, "jax")
---> 32 (els[:, 0] + els[:, 1]).mass
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/highlevel.py:1097, in Array.__getattr__(self, where)
1061 """
1062 Args:
1063 where (str): Attribute name to lookup
(...)
1094 *assigned* as attributes. See #ak.Array.__setitem__ for more.
1095 """
1096 if hasattr(type(self), where):
-> 1097 return super().__getattribute__(where)
1098 else:
1099 if where in self._layout.fields:
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/coffea/nanoevents/methods/vector.py:531, in LorentzVector.mass(self)
525 @property
526 def mass(self):
527 r"""Invariant mass (+, -, -, -)
528
529 :math:`\sqrt{t^2-x^2-y^2-z^2}`
530 """
--> 531 return numpy.sqrt(self.mass2)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/highlevel.py:1097, in Array.__getattr__(self, where)
1061 """
1062 Args:
1063 where (str): Attribute name to lookup
(...)
1094 *assigned* as attributes. See #ak.Array.__setitem__ for more.
1095 """
1096 if hasattr(type(self), where):
-> 1097 return super().__getattribute__(where)
1098 else:
1099 if where in self._layout.fields:
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/coffea/nanoevents/methods/vector.py:523, in LorentzVector.mass2(self)
520 @property
521 def mass2(self):
522 """Squared `mass`"""
--> 523 return _mass2_kernel(self.t, self.x, self.y, self.z)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/highlevel.py:1349, in Array.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
1347 name = f"{type(ufunc).__module__}.{ufunc.__name__}.{method!s}"
1348 with ak._errors.OperationErrorContext(name, inputs, kwargs):
-> 1349 return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_connect/numpy.py:459, in array_ufunc(ufunc, method, inputs, kwargs)
450 out = ak._do.recursively_apply(
451 inputs[where],
452 unary_action,
(...)
455 allow_records=False,
456 )
458 else:
--> 459 out = ak._broadcasting.broadcast_and_apply(
460 inputs, action, behavior, allow_records=False, function_name=ufunc.__name__
461 )
462 assert isinstance(out, tuple) and len(out) == 1
463 out = out[0]
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:1022, in broadcast_and_apply(inputs, action, behavior, depth_context, lateral_context, allow_records, left_broadcast, right_broadcast, numpy_to_regular, regular_to_jagged, function_name, broadcast_parameters_rule)
1020 backend = backend_of(*inputs)
1021 isscalar = []
-> 1022 out = apply_step(
1023 backend,
1024 broadcast_pack(inputs, isscalar),
1025 action,
1026 0,
1027 depth_context,
1028 lateral_context,
1029 behavior,
1030 {
1031 "allow_records": allow_records,
1032 "left_broadcast": left_broadcast,
1033 "right_broadcast": right_broadcast,
1034 "numpy_to_regular": numpy_to_regular,
1035 "regular_to_jagged": regular_to_jagged,
1036 "function_name": function_name,
1037 "broadcast_parameters_rule": broadcast_parameters_rule,
1038 },
1039 )
1040 assert isinstance(out, tuple)
1041 return tuple(broadcast_unpack(x, isscalar, backend) for x in out)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:1001, in apply_step(backend, inputs, action, depth, depth_context, lateral_context, behavior, options)
999 return result
1000 elif result is None:
-> 1001 return continuation()
1002 else:
1003 raise AssertionError(result)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:974, in apply_step.<locals>.continuation()
972 # Any non-string list-types?
973 elif any(x.is_list and not is_string_like(x) for x in contents):
--> 974 return broadcast_any_list()
976 # Any RecordArrays?
977 elif any(x.is_record for x in contents):
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:622, in apply_step.<locals>.broadcast_any_list()
619 nextinputs.append(x)
620 nextparameters.append(NO_PARAMETERS)
--> 622 outcontent = apply_step(
623 backend,
624 nextinputs,
625 action,
626 depth + 1,
627 copy.copy(depth_context),
628 lateral_context,
629 behavior,
630 options,
631 )
632 assert isinstance(outcontent, tuple)
633 parameters = parameters_factory(nextparameters, len(outcontent))
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_broadcasting.py:987, in apply_step(backend, inputs, action, depth, depth_context, lateral_context, behavior, options)
980 else:
981 raise ValueError(
982 "cannot broadcast: {}{}".format(
983 ", ".join(repr(type(x)) for x in inputs), in_function(options)
984 )
985 )
--> 987 result = action(
988 inputs,
989 depth=depth,
990 depth_context=depth_context,
991 lateral_context=lateral_context,
992 continuation=continuation,
993 behavior=behavior,
994 backend=backend,
995 options=options,
996 )
998 if isinstance(result, tuple) and all(isinstance(x, Content) for x in result):
999 return result
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_connect/numpy.py:400, in array_ufunc.<locals>.action(inputs, **ignore)
397 args.append(x)
399 # Give backend a chance to change the ufunc implementation
--> 400 impl = backend.prepare_ufunc(ufunc)
402 # Invoke ufunc
403 result = impl(*args, **kwargs)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_backends/jax.py:50, in JaxBackend.prepare_ufunc(self, ufunc)
47 def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
48 from awkward._connect.jax import get_jax_ufunc
---> 50 return get_jax_ufunc(ufunc)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/awkward/_connect/jax/__init__.py:8, in get_jax_ufunc(ufunc)
7 def get_jax_ufunc(ufunc):
----> 8 return getattr(jax.numpy, ufunc.__name__)
File ~/mambaforge/envs/agc-ad/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute '_mass2_kernel'
This error occurred while calling
numpy._mass2_kernel.__call__(
<Array [192.54099, 132.60043, ..., 142.34727] type='5 * float32'>
<Array [5.5301285, -46.949707, ..., -58.96562] type='5 * float32'>
<Array [-70.93436, -12.467135, ..., -31.510773] type='5 * float32'>
<Array [156.38907, -75.47587, ..., -115.080734] type='5 * float32'>
)
I made some progress understanding what causes this to happen. Here is a significantly simplified reproducer:
import awkward as ak
import numba
import numpy as np
behavior = {}
ak.jax.register_and_check()
USE_JAX = True  # set to False to run this successfully
input_arr = ak.Array([1.0], backend=("jax" if USE_JAX else "cpu"))
@numba.vectorize(
    [
        numba.float32(numba.float32, numba.float32),
        numba.float64(numba.float64, numba.float64),
    ]
)
def _some_kernel(x, y):
    return x * x + y * y
@ak.mixin_class(behavior)
class SomeClass:
    @property
    def some_kernel(self):
        return _some_kernel(self.x, self.y)
ak.behavior.update(behavior)
arr = ak.zip({"x": input_arr, "y": input_arr}, with_name="SomeClass")
arr.some_kernel # crashes with Jax
This results in
AttributeError: module 'jax.numpy' has no attribute '_some_kernel'
This error occurred while calling
numpy._some_kernel.__call__(
<Array [1.0] type='1 * float32'>
<Array [1.0] type='1 * float32'>
)
The code runs successfully with USE_JAX = False. It also works fine when removing the @numba.vectorize decorator from the kernel. I imagine numba + jax are just generically incompatible here. If that is the case and it is expected that this setup does not work, maybe there is a way to improve the error message for such a setup.
Right - at the moment, users can't override ufuncs for JAX, so numba ufuncs throw exceptions. Numba functions wouldn't be differentiable via JAX; we'd need to substitute a JAX implementation.
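To make the failure mode concrete: the traceback shows get_jax_ufunc doing getattr(jax.numpy, ufunc.__name__), so any ufunc whose name is not in jax.numpy fails with a bare AttributeError. Below is a minimal sketch of how a name-based lookup could report this more clearly. It is not Awkward's actual code or a proposed patch. The names lookup_backend_ufunc and fake_jnp are hypothetical, and fake_jnp is only a stand-in for the jax.numpy namespace so that the sketch runs without JAX installed.

```python
import math
import types


def lookup_backend_ufunc(namespace, ufunc, backend_name="jax"):
    """Find the function in `namespace` named like `ufunc`, or explain why not."""
    name = ufunc.__name__
    impl = getattr(namespace, name, None)
    if impl is None:
        # Hypothetical wording; the point is to name the ufunc and the backend.
        raise TypeError(
            f"ufunc {name!r} has no counterpart in the {backend_name!r} backend; "
            "custom ufuncs (for example from numba.vectorize) cannot be "
            "translated by name"
        )
    return impl


# Hypothetical stand-in for the jax.numpy namespace (assumption: one entry only).
fake_jnp = types.SimpleNamespace(sqrt=math.sqrt)


def sqrt(x):
    # Plays the role of numpy.sqrt: a function with this name exists in the namespace.
    return math.sqrt(x)


def _some_kernel(x, y):
    # Plays the role of the numba-vectorized kernel: no such name in the namespace.
    return x * x + y * y


found = lookup_backend_ufunc(fake_jnp, sqrt)

try:
    lookup_backend_ufunc(fake_jnp, _some_kernel)
    message = None
except TypeError as err:
    message = str(err)
```

With this shape of check, the user would see which ufunc could not be translated and for which backend, rather than an AttributeError raised from inside jax.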
@Saransh-cpp, this is another one that you should self-assign (anything with label autodiff, actually).
The coffea issue will be solved once their vector module is removed and scikit-hep/vector is recommended to the users - https://github.com/CoffeaTeam/coffea/issues/874#issuecomment-1941435897
For the issue on the awkward end, I am a bit unsure what we want the ideal behavior to look like:
- Do we want the users to refrain from using Jax and Numba together, or would we like to support doing that (if that is possible)? Or do we just want better error handling here?
- Do we want to recommend Jax's jit mechanism to users when they plan on differentiating their functions? I have not tried jax.jit with awkward, but it might just work.
Thanks!
jax.jit will not work in Awkward—that was something that we determined very early on. Looking at it, it was clear that it would never work because so many of the Awkward kernels need to determine a new buffer's length from an old buffer's values, and that is forbidden in JAX. JAX users find it hard enough to not be able to apply a boolean mask in compiled JAX (because the output array length depends on how many True values are in the mask), but Awkward has to do that sort of thing a lot.
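The shape problem described above can be sketched without JAX or Awkward. In this illustration plain Python lists stand in for buffers, and apply_mask and content_length are hypothetical helpers, not Awkward kernels. Inputs with identical length and type produce outputs of different length, so a compiler that only knows shapes ahead of time cannot decide how large the output buffer should be.

```python
from itertools import accumulate, compress


def apply_mask(values, mask):
    """Keep the entries of `values` where `mask` is True."""
    return list(compress(values, mask))


def content_length(counts):
    """Length of the flat content buffer for lists with the given lengths."""
    # Offsets are the running sum of the counts; the last offset is the
    # total number of items, which depends on the values in `counts`.
    offsets = [0, *accumulate(counts)]
    return offsets[-1]


values = [1.0, 2.0, 3.0, 4.0]

# Two masks of the same length and type select a different number of items.
short = apply_mask(values, [True, False, True, False])
longer = apply_mask(values, [True, True, True, False])

# Two count buffers of the same length describe content buffers of different length.
small = content_length([2, 0, 1])
large = content_length([5, 3, 4])
```

The second helper mirrors the case mentioned above, where a new buffer's length has to be read from an old buffer's values.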
So with JAX's JIT-compilation off the table, the alternative of compiling in Numba is still there, but Numba does not propagate derivatives through its compiled code. From January 2022 until January 2023 (as far as I followed it), @ludgerpaehler was trying to compile through Numba by using Enzyme, an autograd tool for LLVM code. I don't know the current state of that project, but that would allow us to connect JAX's non-JITted autograd with Numba's JITted autograd. Users already have to switch programming models between non-JIT and JIT, but in principle, it's possible to preserve derivatives across that boundary.