Numba implementation of `AdvancedIncSubtensor` doesn't handle duplicate indices correctly
I am getting wrong results for the gradient of a simple Normal logp when using the NUMBA backend (and indexing is involved). I have a gist documenting the problem here: https://gist.github.com/ricardoV94/51e6893faf0b9fc0ddb9b5f90a917af8
The bug seems to originate from this particular behavior of the `AdvancedIncSubtensor` `Op`:
https://github.com/aesara-devs/aesara/blob/e51e87870f0a67fcfe1679e0aa75b5d3fbc0e995/aesara/tensor/subtensor.py#L2612-L2615
Essentially, in that example, the tensor being indexed at the end of the graph is actually [11.41636145, 17.46946674]. The Aesara implementation returns the sum of that, i.e. 28.88582819, but the Numba implementation simply returns the second element of the array. The Python code highlighted above doesn't have a Numba implementation, so essentially it only performs the indexing.
Can it be reproduced more directly?
Here's a minimal reproduction:
```python
import numpy as np
import aesara
import aesara.tensor as at

x = at.as_tensor([0.0])
y = at.vector("y")
idx = at.as_tensor([0, 0])
z = at.inc_subtensor(x[idx], y)

z_numba_fn = aesara.function([y], z, mode="NUMBA")
z_c_fn = aesara.function([y], z)

z_numba_fn(np.r_[1, 1])
# array([1.])
z_c_fn(np.r_[1, 1])
# array([2.])
```
The underlying issue is that `x[idx] += y` handles duplicate entries in `idx` differently than `np.add.at` (e.g. see #561).
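The difference can be seen directly in plain NumPy, without Aesara. In-place fancy-index addition is buffered, so repeated indices only keep the last update, whereas `np.add.at` accumulates every occurrence:

```python
import numpy as np

x = np.zeros(1)
idx = np.array([0, 0])
y = np.array([1.0, 1.0])

# Buffered in-place update: x[0] is read once, so the two
# increments for the duplicated index collapse into one.
x_plain = x.copy()
x_plain[idx] += y
print(x_plain)  # [1.]

# Unbuffered update: each occurrence of index 0 accumulates.
x_at = x.copy()
np.add.at(x_at, idx, y)
print(x_at)  # [2.]
```

This is exactly the discrepancy shown by `z_numba_fn` vs. `z_c_fn` above.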
I believe the approach used in #1081 for `AdvancedIncSubtensor1` can be extended to arbitrary indices (i.e. `AdvancedIncSubtensor`) by using the logic in the index conversion function `expand_indices` to convert arbitrary indices into a single set of broadcasted indices that can be easily iterated over via a `for`-loop in Numba.
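A minimal sketch of that idea, in plain Python (the function name and the restriction to one integer index array per dimension are my assumptions, not the actual `expand_indices` API): broadcast the index arrays against each other, then accumulate with an explicit loop so duplicates add up, matching `np.add.at` semantics. A Numba version would njit-compile the same loop.

```python
import numpy as np


def inc_subtensor_loop(x, indices, y):
    # Hypothetical sketch: `indices` is a tuple of integer index
    # arrays, one per dimension of `x` (a simplification of what
    # expand_indices would produce for arbitrary indices).
    out = x.copy()
    bcast = np.broadcast(*indices)
    # Broadcast the increments to the common index shape.
    y_flat = np.broadcast_to(y, bcast.shape).ravel()
    # Explicit loop: duplicate indices accumulate, unlike `out[idx] += y`.
    for i, idx in enumerate(bcast):
        out[idx] += y_flat[i]
    return out


result = inc_subtensor_loop(np.zeros(2), (np.array([0, 0, 1]),),
                            np.array([1.0, 2.0, 3.0]))
print(result)  # [3. 3.]
```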
The problem was only fixed for `AdvancedIncSubtensor1`, so we still need a fix for `AdvancedIncSubtensor`. We can try the approach mentioned above, which more or less implies a complete Numba implementation for `AdvancedSubtensor`, or we can just use `np.add.at` from object mode (or something else).