Add `einsum`
Description
TODO:
- [x] Support ellipsis (related to https://github.com/dgasmith/opt_einsum/issues/235)
- [x] Exclude broadcastable dims for better perf (JAX does the same)
- [x] Handle missing static shape information (default to left to right contraction?)
- [x] Add rewrite for optimizing Einsum Ops when all inputs have known static shapes
- [x] Add rewrite for inlining optimized Einsum
- [x] Get rid of Blockwise Reshape
- [x] Fix lingering infinite rewriting bug
- [x] Decide on providing `optimize` kwarg
- [x] Appease Mypy
- [ ] Better docstrings (@jessegrabowski self-assigned)
- [x] Fix failing tests (@ricardoV94 self-assigned)
Related Issue
- [x] Closes #57
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
Are the currently failing tests supposed to fail?
Looks like it's related to the changes to the `__len__` method in `variables.py`. I'd suggest just reverting that change unless we really need it. It's a pretty low-level thing, and it would take a bit of work to figure out everything it affects.
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 80.82%. Comparing base (28d9d4d) to head (e262999). Report is 1 commit behind head on main.
:exclamation: Current head e262999 differs from pull request most recent head 3fe3257
Please upload reports for the commit 3fe3257 to get more accurate results.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
- Coverage   80.89%   80.82%   -0.07%
==========================================
  Files         169      164       -5
  Lines       46977    46844     -133
  Branches    11478    11457      -21
==========================================
- Hits        38000    37862     -138
+ Misses       6767     6734      -33
- Partials     2210     2248      +38
```
| Files | Coverage Δ | |
|---|---|---|
| pytensor/compile/builders.py | 77.45% <100.00%> (-10.98%) | :arrow_down: |
| pytensor/link/jax/dispatch/__init__.py | 100.00% <100.00%> (ø) | |
| pytensor/link/jax/dispatch/einsum.py | 100.00% <100.00%> (ø) | |
| pytensor/tensor/einsum.py | 100.00% <100.00%> (ø) | |
All cases except those requiring tensordot with batch dims not on the left are passing
We may need more tests soon enough
Can we reuse how numpy implements tensordot?
> Can we reuse how numpy implements tensordot?
We already do that, but numpy doesn't have a batched tensordot (except, of course, through einsum). We do have batched tensordot working in this PR, just not with arbitrary batch axes; it should just need some extra transposes to get the job done. See the rough sketch below.
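A rough numpy sketch of that idea, using a plain Python loop rather than the PR's Blockwise machinery (the helper name and argument layout are illustrative, not the PR's `_general_dot`):

```python
import numpy as np


def batched_tensordot_sketch(a, b, axes, batch_axes):
    """tensordot over `axes`, applied independently along one batch axis per input."""
    ba, bb = batch_axes
    # Slicing out the batch axis shifts every core axis numbered above it down by one.
    axes_a = [ax - (ax > ba) for ax in axes[0]]
    axes_b = [ax - (ax > bb) for ax in axes[1]]
    return np.stack(
        [
            np.tensordot(np.take(a, i, axis=ba), np.take(b, i, axis=bb), axes=(axes_a, axes_b))
            for i in range(a.shape[ba])
        ]
    )


A = np.random.normal(size=(4, 3, 5))  # batch axis 1
B = np.random.normal(size=(3, 5, 2))  # batch axis 0
out = batched_tensordot_sketch(A, B, axes=([2], [1]), batch_axes=(1, 0))
print(out.shape)  # (3, 4, 2): batch dimension first, then the remaining core axes
```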
Huhu convolutions via einsum work :D
Einsum now also works with inputs with unknown static shapes (unoptimized, of course). We can add a rewrite for when such an Op is found with inputs that by then do have static shapes (this can be quite relevant in PyMC, when users call `freeze_rv_and_dims` on a model with mutable coords).
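A small illustration of that distinction (a sketch; it assumes `einsum` is importable from `pytensor.tensor.einsum`, the module added in this PR):

```python
import pytensor.tensor as pt
from pytensor.tensor.einsum import einsum

# Fully known static shapes: the contraction path can be optimized at graph-construction time.
x = pt.tensor(shape=(5, 4))
y = pt.tensor(shape=(4, 3))
out_static = einsum("ij,jk->ik", x, y)

# Unknown static shapes: falls back to the unoptimized (left-to-right) contraction,
# until a later rewrite can fill in static shapes.
a = pt.tensor(shape=(None, None))
b = pt.tensor(shape=(None, None))
out_dynamic = einsum("ij,jk->ik", a, b)
```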
The ellipsis case is failing due to a bug in opt_einsum: https://github.com/dgasmith/opt_einsum/issues/235
Fix for the ellipsis bug was merged yesterday. I installed from their main branch and the tests now pass, which is nice! I guess we will have to wait for them to cut a new release, though?
> Fix for the ellipsis bug was merged yesterday. I installed from their main branch and the tests now pass, which is nice! I guess we will have to wait for them to cut a new release, though?
I think we can create dummy empty arrays with the shapes to avoid the error for the time being.
I don't know how fast they release
Also tests that are failing on CI are passing for me locally, so I don't really know what's going on with that.
> Also tests that are failing on CI are passing for me locally, so I don't really know what's going on with that.
I think one of the Blockwise reshape rewrites I added is buggy and causes an infinite rewrite loop. Gotta investigate.
I fixed the ellipsis case by using dummy empty arrays with the static shapes. Since all the needed functionality exists in numpy these days, I also dropped the dependency on opt_einsum.
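For context, this is the kind of trick described above, sketched with plain numpy: a contraction path can be planned from `np.empty` dummies that only carry the static shapes, with no real data involved (illustrative shapes; not the PR's code):

```python
import numpy as np

# Dummy operands that only carry the static shapes; no real data is needed
# to plan the contraction.
a = np.empty((10, 20))
b = np.empty((20, 30))
c = np.empty((30, 5))

path, info = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="optimal")
print(path)  # e.g. ['einsum_path', (1, 2), (0, 1)]: the planned pairwise contractions
print(info)  # human-readable summary of the chosen path
```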
Is there any opt in opt_einsum? I thought using that was the value-add over plain numpy.
> Is there any opt in opt_einsum? I thought using that was the value-add over plain numpy.
The opt functionality of opt_einsum was incorporated into numpy itself some time ago (literally line by line).
Getting close to the finish line. We still have to add the optimize flag (or at least decide that we don't want it) and add the rewrite for when a graph is rebuilt with static shapes at a later point.
> Getting close to the finish line. We still have to add the optimize flag (or at least decide that we don't want it) and add the rewrite for when a graph is rebuilt with static shapes at a later point.
I like the optimize option, but we agree it only makes sense when you have static shapes, right? It's not required for an initial implementation, right?
Not sure what you mean @zaxtax; I'm talking about allowing the "optimize" kwarg like the one in numpy, which defines what kind of optimization to do (`optimize: {bool, list, tuple, 'greedy', 'optimal'}`); users can pass their own contraction path as well.
If users pass a contraction path, we don't need to know static shapes. If users set it to greedy/optimal (optimal should be the default), we do need to know them, but we may only find them out later. If they don't want optimization, then obviously we don't need them either.
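For reference, here is what the numpy-style `optimize` argument accepts, including an explicit user-supplied contraction path (plain numpy shown; how this maps onto the pytensor kwarg is exactly the open question above):

```python
import numpy as np

a, b, c = np.empty((10, 20)), np.empty((20, 30)), np.empty((30, 5))

np.einsum("ij,jk,kl->il", a, b, c, optimize=False)      # no path optimization
np.einsum("ij,jk,kl->il", a, b, c, optimize="greedy")   # heuristic path search
np.einsum("ij,jk,kl->il", a, b, c, optimize="optimal")  # exhaustive path search

# An explicit, user-supplied contraction path: contract operands 1 and 2 first,
# then the result with operand 0. No shape-based planning is needed for this.
np.einsum("ij,jk,kl->il", a, b, c, optimize=["einsum_path", (1, 2), (0, 1)])
```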
Oh I misunderstood!
Some unrelated jax test is failing, probably due to something that changed in a recent release? https://github.com/pymc-devs/pytensor/actions/runs/10161294793/job/28099514375?pr=722#step:6:778
Yes, I am also looking at this now. It's a jax bug that can be recreated easily:

```python
import jax

jax.jit(jax.numpy.tri)(3, 3, 0)
```

We can ignore it. Looks like their `_canonicalize_axis` function is underflowing (something like `np.uint32(-1)`).
I think I fixed the tests (not the JAX one) and appeased mypy. @jessegrabowski, the docstring extensions are left to you.
I stopped force-pushing, in case you want to take over.
Opened an issue here: https://github.com/google/jax/issues/22751
I'll hit the docstrings ASAP if that's all that's holding this up
> Opened an issue here: https://github.com/google/jax/issues/22751
Great, let's just mark it as xfail then.
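A minimal sketch of such a marker, assuming pytest; the test name is illustrative and the body is just the reproducer from above, not the actual failing test in this PR:

```python
import pytest


@pytest.mark.xfail(reason="JAX bug in jitted tri: https://github.com/google/jax/issues/22751")
def test_tri_under_jit():
    import jax

    # Reproducer from the discussion above; expected to fail until the upstream fix lands.
    jax.jit(jax.numpy.tri)(3, 3, 0)
```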
First pass on docstrings. Working on the doctests revealed two things:
- Our implementation of `_delta` does not agree with that of JAX (see the sketch after this list):

  ```python
  from jax._src.lax.lax import _delta as jax_delta
  from pytensor.tensor.einsum import _delta as pt_delta

  jax_delta(int, (3, 3, 3), (0, 1))
  # Array([[[1, 1, 1],
  #         [0, 0, 0],
  #         [0, 0, 0]],
  #        [[0, 0, 0],
  #         [1, 1, 1],
  #         [0, 0, 0]],
  #        [[0, 0, 0],
  #         [0, 0, 0],
  #         [1, 1, 1]]], dtype=int32)

  pt_delta((3, 3, 3), (0, 1)).astype(int).eval()
  # array([[[1, 0, 0],
  #         [0, 1, 0],
  #         [0, 0, 1]],
  #        [[1, 0, 0],
  #         [0, 1, 0],
  #         [0, 0, 1]],
  #        [[1, 0, 0],
  #         [0, 1, 0],
  #         [0, 0, 1]]])
  ```

  We seem to always output the `axes=(1, 2)` case, regardless of what the requested axes were.
- Our `_general_dot` function is discarding shape information somewhere, which doesn't seem right:

  ```python
  import pytensor.tensor as pt
  from pytensor.tensor.einsum import _general_dot

  A = pt.tensor(shape=(3, 4, 5))
  B = pt.tensor(shape=(3, 5, 2))
  result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
  print(result.type.shape)
  # (3, None, None)
  ```
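Regarding the first point, here is a minimal sketch (not the PR's `_delta`) that reproduces the generalized Kronecker delta semantics shown in the JAX output above, i.e. 1 exactly where the indices along the requested `axes` coincide, broadcast over the remaining axes:

```python
import pytensor.tensor as pt


def delta_sketch(shape, axes):
    # Build an index grid along each requested axis, shaped so that the grids
    # broadcast against the full output shape.
    iotas = []
    for axis in axes:
        idx_shape = [1] * len(shape)
        idx_shape[axis] = shape[axis]
        iotas.append(pt.arange(shape[axis]).reshape(idx_shape))
    # 1 where the indices along all requested axes are equal, 0 elsewhere.
    eq = pt.ones(shape, dtype="bool")
    for first, second in zip(iotas[:-1], iotas[1:]):
        eq = eq & pt.eq(first, second)
    return eq


delta_sketch((3, 3, 3), (0, 1)).astype(int).eval()
# matches the JAX output above: ones where i == j, broadcast along the last axis
```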
Static shape is optional, not a requirement. In our case it probably has to do with the reshape introduced by tensordot and/or Blockwise, which doesn't do any special shape inference (static or at rewrite time) for core shapes.
That's something we probably want to address for Blockwise in the Numba backend.
I understand it's optional, but it also shouldn't be discarded if available, no?
We are not discarding anything on purpose, but an intermediate Op (or Blockwise) doesn't know how to provide a more precise output shape.
There can also be a tradeoff: quite some effort may be needed to figure out a static shape, and that may not be worth it at define time. Anyway, the main point is that it shouldn't be a blocker.
We can open an issue for whatever Op is losing the static shape and then assess if it's worth the cost or not