Handle duplicate indices in Numba implementation of `AdvancedIncSubtensor1`
Here are a few important guidelines and requirements to check before your PR can be merged:
- [x] There is an informative high-level description of the changes.
- [x] The description and/or commit message(s) references the relevant GitHub issue(s).
- [ ] `pre-commit` is installed and set up.
- [x] The commit messages follow these guidelines.
- [x] The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/PR.
- [x] There are tests covering the changes introduced in the PR.
First, this PR removes the incorrect Numba implementation of `AdvancedIncSubtensor`, so it will now just fall back to object mode and be slow but correct. (We should also provide a Numba implementation for this, but that will be a separate PR.)
But we do add a new implementation for `AdvancedIncSubtensor1` (the much more common case). Here, we also take advantage of the fact that sometimes we know the indices beforehand, so we can simplify bounds checks and generate cleaner and faster assembly code.
Also, a one-line fix for `CumOp` accidentally made its way into the PR, but it is simple enough that maybe we can just keep it here? (It has its own commit.)
The problem was that the numba.prange loop had data races.
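To illustrate the race (a minimal sketch, not the removed dispatch code itself): with duplicate indices, two parallel iterations can read the same element of `x` before either writes back, so one increment is lost.

```python
# Minimal illustrative sketch of the data race; not the removed Aesara code.
import numba
import numpy as np


@numba.njit(parallel=True)
def racy_inc(x, idxs, vals):
    # When idxs contains duplicates, iterations running in parallel can both
    # read the same x[idxs[i]] before either one has written its update back.
    for i in numba.prange(idxs.shape[0]):
        x[idxs[i]] += vals[i]
    return x


# With many duplicate indices, the parallel version will typically lose updates:
# racy_inc(np.zeros(4), np.zeros(10_000, dtype=np.int64), np.ones(10_000))
```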
fixes https://github.com/aesara-devs/aesara/issues/603
Codecov Report
Merging #1081 (1e6be72) into main (f2a7fb9) will decrease coverage by 0.04%.
The diff coverage is 85.48%.
:exclamation: Current head 1e6be72 differs from pull request most recent head f3f65ad. Consider uploading reports for the commit f3f65ad to get more accurate results
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1081      +/-   ##
==========================================
- Coverage   79.28%   79.24%   -0.05%
==========================================
  Files         159      152       -7
  Lines       48111    48002     -109
  Branches    10937    10922      -15
==========================================
- Hits        38145    38038     -107
- Misses       7454     7458       +4
+ Partials     2512     2506       -6
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| aesara/link/numba/dispatch/basic.py | 90.73% <85.24%> (-1.34%) | :arrow_down: |
| aesara/link/numba/dispatch/extra_ops.py | 98.00% <100.00%> (ø) | |
| aesara/tensor/basic_opt.py | 85.90% <0.00%> (-14.10%) | :arrow_down: |
| aesara/tensor/subtensor_opt.py | 87.12% <0.00%> (-12.88%) | :arrow_down: |
| aesara/tensor/math_opt.py | 87.27% <0.00%> (-12.73%) | :arrow_down: |
| aesara/tensor/opt_uncanonicalize.py | 96.03% <0.00%> (-3.97%) | :arrow_down: |
| aesara/ifelse.py | 49.71% <0.00%> (-1.30%) | :arrow_down: |
| aesara/link/numba/dispatch/random.py | 97.74% <0.00%> (-1.09%) | :arrow_down: |
| aesara/sandbox/multinomial.py | 75.91% <0.00%> (-0.61%) | :arrow_down: |
| ... and 58 more | | |
Is it possible that codecov doesn't pick up changes when I force-push? Somehow I still see the old version of the patch in the coverage report...
Comparison of the constant index case with the non-constant index case:
```python
%env NUMBA_BOUNDSCHECK=0
import aesara
import aesara.tensor as at
import numpy as np
n, k = 100_000, 100
idxs_vals = np.random.randint(k, size=n)
#idxs_vals.sort()
x_vals = np.random.randn(k)
a_vals = np.random.randn(n)
x = at.dvector("x")
a = at.dvector("d")
idxs = at.vector("idx", dtype=np.int64)
out = at.inc_subtensor(x[idxs], a)
func = aesara.function([idxs, x, a], out, mode="NUMBA")
func_inner = func.vm.jit_fn
_ = func_inner(idxs_vals, x_vals, a_vals)
print("time with non-const index:")
%timeit func_inner(idxs_vals, x_vals, a_vals)
x = at.dvector("x")
a = at.dvector("d")
out = at.inc_subtensor(x[idxs_vals], a)
func = aesara.function([x, a], out, mode="NUMBA")
func_inner = func.vm.jit_fn
func_inner(x_vals, a_vals);
print("time with const index:")
%timeit func_inner(x_vals, a_vals)
```
```
time with non-const index:
90.9 µs ± 3.61 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
time with const index:
62.7 µs ± 974 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
So the non-const index case is about 1.4x slower. If we enable boundschecks, this increases to ~2x.
The asm of the non-const version without boundschecks looks ok, but it has to deal with the possibility of negative indices:
```asm
.LBB3_13:
movq (%rdx), %rax
vmovsd (%rsi), %xmm0
movq %rax, %rdi
sarq $63, %rdi
andq %r12, %rdi
addq %rax, %rdi
vaddsd (%rbx,%rdi,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbx,%rdi,8)
leaq (%r9,%rdx), %rax
movq (%r9,%rdx), %rdx
leaq (%rcx,%rsi), %rdi
vmovsd (%rcx,%rsi), %xmm0
movq %rdx, %rsi
sarq $63, %rsi
andq %r12, %rsi
addq %rdx, %rsi
vaddsd (%rbx,%rsi,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbx,%rsi,8)
leaq (%r9,%rax), %rsi
movq (%r9,%rax), %rax
leaq (%rcx,%rdi), %rbp
vmovsd (%rcx,%rdi), %xmm0
movq %rax, %rdx
sarq $63, %rdx
andq %r12, %rdx
addq %rax, %rdx
vaddsd (%rbx,%rdx,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbx,%rdx,8)
leaq (%r9,%rsi), %rdx
movq (%r9,%rsi), %rax
vmovsd (%rcx,%rbp), %xmm0
movq %rax, %rdi
sarq $63, %rdi
andq %r12, %rdi
addq %rax, %rdi
vaddsd (%rbx,%rdi,8), %xmm0, %xmm0
leaq (%rcx,%rbp), %rsi
vmovsd %xmm0, (%rbx,%rdi,8)
addq %r9, %rdx
addq %rcx, %rsi
addq $-4, %r15
jne .LBB3_13
```
If we enable boundschecks, we get branching in the loop:
```asm
.LBB3_8:
movq (%rdx), %rbx
movq %rbx, %rax
sarq $63, %rax
andq %r12, %rax
addq %rbx, %rax
cmpq %r12, %rax
jge .LBB3_10
testq %rax, %rax
js .LBB3_10
vmovsd (%rdi), %xmm0
vaddsd (%rbp,%rax,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbp,%rax,8)
addq %rcx, %rdx
addq %rsi, %rdi
decq %r14
jne .LBB3_8
```
In comparison the constant-index case looks pretty nice:
```asm
.LBB3_10:
movq (%rdi,%rbx), %rax
vmovsd (%rcx), %xmm0
vaddsd (%rbp,%rax,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbp,%rax,8)
movq 8(%rdi,%rbx), %rax
vmovsd (%r8,%rcx), %xmm0
vaddsd (%rbp,%rax,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbp,%rax,8)
movq 16(%rdi,%rbx), %rax
vmovsd (%rcx,%r8,2), %xmm0
vaddsd (%rbp,%rax,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbp,%rax,8)
movq 24(%rdi,%rbx), %rax
vmovsd (%rsi,%rcx), %xmm0
vaddsd (%rbp,%rax,8), %xmm0, %xmm0
vmovsd %xmm0, (%rbp,%rax,8)
addq %rdx, %rcx
addq $32, %rdi
cmpq $8192, %rdi
jne .LBB3_10
```
And we get safe indexing even without boundscheck=True.
This also raises the question of how we want to deal with out-of-bounds access by default.
I'm not really comfortable with a default implementation that doesn't check bounds for indexing. Should we just enable bounds checking by default through a config option (right now we use the Numba default of boundscheck=False), and allow users to override that if they ask explicitly? Or should we write ops like indexing, which consume user input, so that they are safe even when boundscheck=False?
You've demonstrated that there's possibly a clear difference between constant and non-constant inputs, but we really need to know whether or not all the extra code is providing value, and only a comparison with and without it would help determine that. Also, it's better if you provide the generated LLVM IR instead of the ASM generated for your machine.
I turned part of it into a rewrite, that makes it a bit cleaner. Apart from that I'm not really sure what extra code you are referring to.
Numba's default for boundschecks is False, so unless we change that, this means that `at.as_tensor_variable(np.zeros(2))[3]` has undefined behavior. I think this is a terrible API. We don't have this problem in most Numba ops, because most of the time user input can't make us access incorrect memory even if boundschecks are off, because we (hopefully) control the bounds of loops correctly.
So I think we really need boundschecks of some sort by default.
We can however sometimes tell if boundschecks are unnecessary and optimize them away (something that numba can't do on its own), so why wouldn't we?
Numba's default for boundschecks is False, so unless we change that, this means that `at.as_tensor_variable(np.zeros(2))[3]` has undefined behavior. I think this is a terrible API. We don't have this problem in most Numba ops, because most of the time user input can't make us access incorrect memory even if boundschecks are off, because we (hopefully) control the bounds of loops correctly. So I think we really need boundschecks of some sort by default. We can however sometimes tell if boundschecks are unnecessary and optimize them away (something that Numba can't do on its own), so why wouldn't we?
Aesara's responsibility is to faithfully preserve the results of explicitly defined computations for valid inputs when transpiling to Numba and other targets. Since most errors aren't specified in an Aesara graph—aside from RaiseOps—they don't fall under that responsibility. In general, we currently aren't trying to preserve all of the behavior of a single target—Python included—across all other transpilation targets.
Regardless, the scope of this PR—and its related issue—does not cover manual bounds checking. We can discuss it in a new issue or Discussion, though.
I'm actually a bit shocked you would accept something in aesara where we access invalid memory for wrong user input by default. I am not going to remove the bounds checks from the PR; I'd feel responsible for all the headaches that would lead to.
Aesara's responsibility is to faithfully preserve the results of explicitly defined computations for valid inputs when transpiling to Numba and other targets
Not sure where that is coming from. I'd say aesara's responsibility is to produce decent code, however that happens. And invalid memory access is certainly not that.
I'm actually a bit shocked you would accept something in aesara where we access invalid memory for wrong user input by default.
You seem to be aware that bounds checking already exists in Numba, but you're also assigning the same responsibility to Aesara. If you have this much disgust for a lack of bounds checking, then you need to take that up with Numba—and a few other programming languages, as well.
As I said above, if you found a bug in Numba's bounds checking that's solved by your implementation, please report it to them. We will always consider adding code that works around a current Numba bug or missing feature, but that doesn't seem to be the case here. If it is, tell us.
As you mentioned, we can override Numba's defaults and compile these graphs with bounds checking turned on by default. That's a viable approach. That's also a completely independent change; one that does not factor into the issue addressed by this PR.
I am not going to remove boundchecks from the PR, I'd feel responsible for all the headache that would lead to.
That's fine; we can take care of it.
Aesara's responsibility is to faithfully preserve the results of explicitly defined computations for valid inputs when transpiling to Numba and other targets
Not sure where that is coming from.
That's in reference to the kinds of computations that should be expressed in our Numba implementations of Aesara nodes. If you read the rest of what I wrote, you'll see how it relates to explicit error handling like the kind you've added.
I'd say aesara's responsibility is to produce decent code, however that happens. And invalid memory access is certainly not that.
Unless the code was explicitly designed/intended to prevent invalid memory accesses caused by bad user input, such an error says nothing about the quality of the code. It only says something about the purpose and/or expectations of the code. Sometimes the purpose/expectations for code involves performance, and bounds checking can hinder performance quite a bit. In that case, bounds checking would not make for decent code.
Regardless, unnecessary redundancy does not make code more decent, so, unless your additions are addressing something currently missing from Numba (as mentioned above), these changes are not more decent than the same code without the redundant bounds checking.
I kind of hope we are just talking past each other here, so I'll just summarize a bit, and hopefully that helps:
AdvancedIncSubtensor1 uses user-defined indices, so I think there have to be boundschecks of some kind by default. I don't think this necessarily needs to be the case for most other ops, because most of the time we will only access invalid memory if there is a bug in aesara, but not if there is a bug in the user code.
Checked array access is the norm all over the Python ecosystem (Python itself, NumPy, SciPy, sklearn, PyTorch, JAX, TensorFlow...). Numba is the only exception I can think of right now where it is not on by default. I don't actually agree with this default, but I also think it is still much more reasonable in Numba itself compared to aesara, because if you use Numba directly it is much less hidden.
This is why I set boundschecks=True on the numba implementation of that op, so that we use the numba boundscheck support by default.
There is however a very common case where we can eliminate the boundschecks during graph execution and still have safe array access: If the array of indices is known at compile time we can simply pre-compute the maximum and minimum entry, and check at graph execution time if those are valid for the other input arrays.
I proposed two different implementations of that, one where this happens entirely in the numba dispatch, and one where I moved it to the graph itself using a rewrite.
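For illustration, here is a minimal sketch of that constant-index idea (the helper name and structure are hypothetical, not the PR's actual dispatch code): the bounds of the known index array are computed once, so the per-call check reduces to two comparisons instead of one per element.

```python
import numba


def make_inc_subtensor1(constant_idxs):
    # The index array is known ahead of time, so its extrema can be
    # precomputed once rather than bounds-checking every element.
    idx_min = int(constant_idxs.min())
    idx_max = int(constant_idxs.max())

    @numba.njit(boundscheck=False)
    def inc_subtensor1(x, vals):
        # Two comparisons per call replace a per-element bounds check,
        # while still rejecting indices that would touch invalid memory.
        if idx_min < -x.shape[0] or idx_max >= x.shape[0]:
            raise IndexError("index out of bounds")
        for i in range(constant_idxs.shape[0]):
            x[constant_idxs[i]] += vals[i]
        return x

    return inc_subtensor1
```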
Sounds like numba gives us the option to enable boundchecks but doesn't do it by default. What's wrong with opting to use that in Aesara?
We can have a numba-specific Aesara flag for disabling boundchecks introduced in numba Ops, if we don't want to add a numba specific variable at the Op level.
Also what's with that check_input class variable in these Ops that doesn't seem to be used for anything?
Sounds like numba gives us the option to enable boundchecks but doesn't do it by default. What's wrong with opting to use that in Aesara?
We can have a numba-specific Aesara flag for disabling boundchecks introduced in numba Ops, if we don't want to add a numba specific variable at the Op level.
Yes, if we want bounds checking in Numba, we need to use Numba's bound checking. That's it.
Also what's with that `check_input` class variable in these `Op`s that doesn't seem to be used for anything?
I think that variable is used by COp to perform C-level checks/validation.
I give up. Sorry @ricardoV94...
Yes, if we want bounds checking in Numba, we need to use Numba's bound checking. That's it.
Isn't this what the current PR proposed? Or are you referring to the small check inner-function?
Isn't this what the current PR proposed? Or are you referring to the small `check` inner function?
This PR brings bounds checking to our graphs by adding a new property to AdvancedIncSubtensor1 Ops; however, we don't benefit from having bounds checking at the graph-level (i.e. via explicit CheckAndRaise nodes).
Our Python and C backends already perform bounds checking, and extra work would be needed in order to provide versions that don't (and use the newly introduced property). Likewise, if we're going to do anything with bounds checking at the graph-level, we would need to do it for all *Subtensor* Ops, and not just AdvancedIncSubtensor1.
Also, has anyone considered how adding a new property like that to AdvancedIncSubtensor1 affects other operations (e.g. node merging)?
Anyway, the approach in this PR consists of a much bigger set of changes than we need. We only need a simple implementation in the spirit of
```python
@numba_njit(boundscheck=boundscheck)
def advancedincsubtensor1(x, vals, idxs):
    for idx, val in zip(idxs, vals):
        x[idx] += val
    return x
```
where boundscheck is potentially pulled from an aesara.config option or something similar. At the very least, we don't want to override a user's local/environment's Numba config options, so, if a user sets NUMBA_BOUNDSCHECK, Aesara should honor that value. That's a hard requirement.
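A rough sketch of how that could be wired up (the option name `numba__boundscheck` is an assumption, not an existing Aesara option); note that Numba's NUMBA_BOUNDSCHECK environment variable takes precedence over the per-function flag, so a user's environment setting would still win:

```python
import numba

from aesara import config


def numba_njit(*args, **kwargs):
    # `numba__boundscheck` is a hypothetical config option; fall back to
    # Numba's documented default (False) when it is not set. A user-set
    # NUMBA_BOUNDSCHECK environment variable overrides whatever is passed here.
    kwargs.setdefault("boundscheck", getattr(config, "numba__boundscheck", False))
    return numba.njit(*args, **kwargs)
```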
Also, if we're adding bounds checking for AdvancedIncSubtensor, then we need to add it to all the other Numba *Subtensor* implementations. That's one of the reasons why this is a distinct issue that should be addressed in another PR.
I've created https://github.com/aesara-devs/aesara/pull/1143 to cover the issue associated with this PR. We can address the default bounds checking after that. In the meantime, if anyone wants (or ever wanted) bounds checking in Numba, they should be able to set NUMBA_BOUNDSCHECK, as Numba itself advises doing.