`Record` support for `==` and `!=`
Description of new feature
I was surprised to find that `==` and `!=` don't work for Records or RecordArrays:
>>> import awkward as ak
>>> record = ak.Record({'a': 1})
>>> record == record
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/numpy/lib/mixins.py", line 21, in func
return ufunc(self, other)
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/highlevel.py", line 2028, in __array_ufunc__
return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_connect/numpy.py", line 294, in array_ufunc
out = ak._broadcasting.broadcast_and_apply(
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 1061, in broadcast_and_apply
out = apply_step(
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 1040, in apply_step
return continuation()
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 757, in continuation
outcontent = apply_step(
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 1040, in apply_step
return continuation()
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 473, in continuation
return apply_step(
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 1040, in apply_step
return continuation()
File "/home/user/miniconda3/envs/func_adl_uproot_rc/lib/python3.10/site-packages/awkward/_broadcasting.py", line 953, in continuation
raise ak._errors.wrap_error(
ValueError: while calling
numpy.equal.__call__(
<Record {a: 1} type='{a: int64}'>
<Record {a: 1} type='{a: int64}'>
)
Error details: cannot broadcast records in equal
It looks like this traces back to https://github.com/scikit-hep/awkward/issues/457. To quote from there:
Note that NumPy does not define such an operation on structured arrays:
>>> np_array = np.array([(1, 1.1), (2, 2.2), (3, 3.3)], [("x", int), ("y", float)])
>>> np_array
array([(1, 1.1), (2, 2.2), (3, 3.3)], dtype=[('x', '<i8'), ('y', '<f8')])
>>> np_array + 1
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: invalid type promotion
This is true for most ufuncs, but NumPy does support this for equality and inequality:
>>> import numpy as np
>>> np_array = np.array([(1, 1.1), (2, 2.2), (3, 3.3)], [("x", int), ("y", float)])
>>> np_array == np_array
array([ True, True, True])
>>> np_array != np_array
array([False, False, False])
In the meeting, we talked about this: records (at any depth) are prevented from evaluating ufuncs to avoid some very subtle bugs if someone thinks that behaviors (like Vector) are installed but they are not installed. In particular, we don't want records adding like
{"rho": 1, "phi": 2} + {"rho": 3, "phi": 0} → {"rho": 4, "phi": 2}
because vector.register_awkward() didn't get called. NumPy's structured arrays would do this, too, but it's less of a pain point because it's easier to know that you have Vector's NumPy subclass. Behaviors aren't implicitly installed through ak.behavior (because NumPy structures aren't as deep).
However, you have a good point that == and != should be exceptions to this rule. Although someone could define custom overloads for == and != that are different from just checking to see if all of the fields match, it's usually a good assumption that equality means "all of the fields match." So, equality/inequality is special: it's not like addition or other operations, for which the naive rule is likely very bad (do the wrong calculation without warning).
We can handle this by putting == and != overloads in the built-in Awkward ak.behavior, just as we have for the string behaviors. It would be a new submodule in the awkward/behaviors directory.
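To make the "all of the fields match" rule concrete, here is a minimal sketch using NumPy structured arrays. The helper name `fieldwise_equal` is hypothetical (it is not part of NumPy or Awkward), and the sketch assumes both inputs share the same structured dtype. It only illustrates that NumPy's built-in `==` on such arrays agrees with a field-by-field comparison combined with a logical AND.

```python
import numpy as np


def fieldwise_equal(left, right):
    """Hypothetical helper: compare two structured arrays one field at a time.

    Assumes `left` and `right` have the same structured dtype and shape.
    """
    result = np.ones(left.shape, dtype=bool)
    for name in left.dtype.names:
        # An element is equal only if every one of its fields is equal.
        result &= left[name] == right[name]
    return result


dtype = [("x", int), ("y", float)]
np_array = np.array([(1, 1.1), (2, 2.2), (3, 3.3)], dtype=dtype)
other = np.array([(1, 1.1), (2, 0.0), (3, 3.3)], dtype=dtype)

# The field-by-field result matches NumPy's own structured-array equality.
print(fieldwise_equal(np_array, other))
print(np_array == other)
```

Unlike addition, there is little room here for a silently wrong answer: the second element differs in `y`, so it compares unequal under both approaches.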
Hello,
@kpachal, @mswiatlo, and I are using awkward in the development of Coffea with @lgray. It is going well so far, but we are having a problem very similar to this one. It is not exactly the same, so we are happy to open a new issue if you prefer.
We would like to be able to compare nested arrays with == and !=. This is something that can be done in numpy:
>>> import numpy as np
>>> l1 = [[1],[2]]
>>> l2 = [[1,2],[1,2]]
>>> np.array(l1) == np.array(l2)
array([[ True, False],
[False, True]])
However, when we try to do the same with awkward:
>>> import awkward as ak
>>> l1 = [[1],[2]]
>>> l2 = [[1,2],[1,2]]
>>> ak.Array(l1) == ak.Array(l2)
Traceback (most recent call last)
File ~/anaconda3/lib/python3.10/site-packages/awkward/highlevel.py:1356, in Array.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
1355 with ak._errors.OperationErrorContext(name, inputs, kwargs):
-> 1356 return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)
File ~/anaconda3/lib/python3.10/site-packages/awkward/_connect/numpy.py:326, in array_ufunc(ufunc, method, inputs, kwargs)
325 else:
--> 326 out = ak._broadcasting.broadcast_and_apply(
327 inputs, action, behavior, allow_records=False, function_name=ufunc.__name__
328 )
329 assert isinstance(out, tuple) and len(out) == 1
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:1057, in broadcast_and_apply(inputs, action, behavior, depth_context, lateral_context, allow_records, left_broadcast, right_broadcast, numpy_to_regular, regular_to_jagged, function_name, broadcast_parameters_rule)
1056 isscalar = []
-> 1057 out = apply_step(
1058 backend,
1059 broadcast_pack(inputs, isscalar),
1060 action,
1061 0,
1062 depth_context,
1063 lateral_context,
1064 behavior,
1065 {
1066 "allow_records": allow_records,
1067 "left_broadcast": left_broadcast,
1068 "right_broadcast": right_broadcast,
1069 "numpy_to_regular": numpy_to_regular,
1070 "regular_to_jagged": regular_to_jagged,
1071 "function_name": function_name,
1072 "broadcast_parameters_rule": broadcast_parameters_rule,
1073 },
1074 )
1075 assert isinstance(out, tuple)
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:1036, in apply_step(backend, inputs, action, depth, depth_context, lateral_context, behavior, options)
1035 elif result is None:
-> 1036 return continuation()
1037 else:
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:1009, in apply_step.<locals>.continuation()
1008 elif any(x.is_list for x in contents):
-> 1009 return broadcast_any_list()
1011 # Any RecordArrays?
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:586, in apply_step.<locals>.broadcast_any_list()
584 nextinputs.append(x)
--> 586 outcontent = apply_step(
587 backend,
588 nextinputs,
589 action,
590 depth + 1,
591 copy.copy(depth_context),
592 lateral_context,
593 behavior,
594 options,
595 )
596 assert isinstance(outcontent, tuple)
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:1036, in apply_step(backend, inputs, action, depth, depth_context, lateral_context, behavior, options)
1035 elif result is None:
-> 1036 return continuation()
1037 else:
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:1009, in apply_step.<locals>.continuation()
1008 elif any(x.is_list for x in contents):
-> 1009 return broadcast_any_list()
1011 # Any RecordArrays?
File ~/anaconda3/lib/python3.10/site-packages/awkward/_broadcasting.py:725, in apply_step.<locals>.broadcast_any_list()
724 elif isinstance(x, listtypes):
--> 725 nextinputs.append(x._broadcast_tooffsets64(offsets).content)
726 # Handle implicit left-broadcasting (non-NumPy-like broadcasting).
File ~/anaconda3/lib/python3.10/site-packages/awkward/contents/listoffsetarray.py:397, in ListOffsetArray._broadcast_tooffsets64(self, offsets)
394 if index_nplike.known_data and not index_nplike.array_equal(
395 this_zero_offsets, offsets
396 ):
--> 397 raise ValueError("cannot broadcast nested list")
399 return ListOffsetArray(
400 offsets, next_content[: offsets[-1]], parameters=self._parameters
401 )
ValueError: cannot broadcast nested list
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[22], line 5
2 l2 = [[1,2],[1,2]]
4 print(np.array(l1) == np.array(l2))
----> 5 print(ak.Array(l1) == ak.Array(l2))
File ~/anaconda3/lib/python3.10/site-packages/awkward/_operators.py:50, in _binary_method.<locals>.func(self, other)
48 if _disables_array_ufunc(other):
49 return NotImplemented
---> 50 return ufunc(self, other)
File ~/anaconda3/lib/python3.10/site-packages/awkward/highlevel.py:1355, in Array.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
1290 """
1291 Intercepts attempts to pass this Array to a NumPy
1292 [universal functions](https://docs.scipy.org/doc/numpy/reference/ufuncs.html)
(...)
1352 See also #__array_function__.
1353 """
1354 name = f"{type(ufunc).__module__}.{ufunc.__name__}.{method!s}"
-> 1355 with ak._errors.OperationErrorContext(name, inputs, kwargs):
1356 return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)
File ~/anaconda3/lib/python3.10/site-packages/awkward/_errors.py:63, in ErrorContext.__exit__(self, exception_type, exception_value, traceback)
60 try:
61 # Handle caught exception
62 if exception_type is not None and self.primary() is self:
---> 63 self.handle_exception(exception_type, exception_value)
64 finally:
65 # `_kwargs` may hold cyclic references, that we really want to avoid
66 # as this can lead to large buffers remaining in memory for longer than absolutely necessary
67 # Let's just clear this, now.
68 self._kwargs.clear()
File ~/anaconda3/lib/python3.10/site-packages/awkward/_errors.py:78, in ErrorContext.handle_exception(self, cls, exception)
76 self.decorate_exception(cls, exception)
77 else:
---> 78 raise self.decorate_exception(cls, exception)
ValueError: cannot broadcast nested list
This error occurred while calling
numpy.equal.__call__(
<Array [[1], [2]] type='2 * var * int64'>
<Array [[1, 2], [1, 2]] type='2 * var * int64'>
)
Thank you for the help and please let us know if we should open this as a new issue.
Hi @jbrewster7, these two arrays are not considered broadcastable because they both have ragged, unequal sublist lengths. The broadcasting rules are outlined partially here.
>>> import awkward as ak
>>> l1 = ak.Array([[1],[2]])
>>> l1.type.show()
2 * var * int64
>>> l2 = ak.Array([[[1,2],[1,2]]])
>>> l2.type.show()
1 * var * var * int64
If you want to perform length-1 broadcasting, as NumPy does, then the length-1 sublists must be regular. You can ensure this by converting from a ragged dimension to a regular one, using ak.to_regular:
l1 = ak.Array(ak.to_regular([[1],[2]], axis=1))
l2 = ak.Array([[[1,2],[1,2]]])
(l1 == l2).show(type=True)
This answer is terse; I'm short on time. But I know that Jim will likely follow up here :)
However, you have a good point that == and != should be exceptions to this rule. Although someone could define custom overloads for == and != that are different from just checking to see if all of the fields match, it's usually a good assumption that equality means "all of the fields match." So, equality/inequality is special: it's not like addition or other operations, for which the naive rule is likely very bad (do the wrong calculation without warning).
We can handle this by putting == and != overloads in the built-in Awkward ak.behavior, just as we have for the string behaviors. It would be a new submodule in the awkward/behaviors directory.
Strings are no longer implemented through behaviors, but I think it's still true that == and !=, and only these comparisons, should be implemented for RecordArrays. Just as for strings, it should be hard-coded: if a Record type doesn't have an overload defined for == and !=, the default should be to do (== for all) or (!= for any) of the fields. The alternative (current behavior) is to raise an error, so this wouldn't be breaking any user code that currently works.