
Add `einsum`

Open jessegrabowski opened this issue 1 year ago • 30 comments

Description

TODO:

  • [x] Support ellipsis (related to https://github.com/dgasmith/opt_einsum/issues/235)
  • [x] Exclude broadcastable dims for better perf (JAX does the same)
  • [x] Handle missing static shape information (default to left to right contraction?)
    • [x] Add rewrite for optimizing Einsum Ops when all inputs have known static shapes
    • [x] Add rewrite for inlining optimized Einsum
  • [x] Get rid of Blockwise Reshape
    • [x] Fix lingering infinite rewriting bug
  • [x] Decide on providing optimize kwarg
  • [x] Appease Mypy
  • [ ] Better docstrings (@jessegrabowski self-assigned)
  • [x] Fix failing tests (@ricardoV94 self-assigned)

Related Issue

  • [x] Closes #57

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

jessegrabowski avatar Apr 19 '24 16:04 jessegrabowski

Are the currently failing tests supposed to fail?

zaxtax avatar Apr 20 '24 23:04 zaxtax

Looks like it's related to the changes to the `__len__` method in `variables.py`. I'd suggest just reverting that change unless we really need it. It's a pretty low-level thing that would need a bit of work to figure out everything it affects.

jessegrabowski avatar Apr 20 '24 23:04 jessegrabowski

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 80.82%. Comparing base (28d9d4d) to head (e262999). Report is 1 commit behind head on main.

:exclamation: Current head e262999 differs from pull request most recent head 3fe3257

Please upload reports for the commit 3fe3257 to get more accurate results.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
- Coverage   80.89%   80.82%   -0.07%     
==========================================
  Files         169      164       -5     
  Lines       46977    46844     -133     
  Branches    11478    11457      -21     
==========================================
- Hits        38000    37862     -138     
+ Misses       6767     6734      -33     
- Partials     2210     2248      +38     
Files                                     Coverage Δ
pytensor/compile/builders.py              77.45% <100.00%> (-10.98%) :arrow_down:
pytensor/link/jax/dispatch/__init__.py    100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/einsum.py      100.00% <100.00%> (ø)
pytensor/tensor/einsum.py                 100.00% <100.00%> (ø)

... and 55 files with indirect coverage changes

codecov[bot] avatar Apr 21 '24 13:04 codecov[bot]

All cases except those requiring tensordot with batch dims not on the left are passing

We may need more tests soon enough

ricardoV94 avatar May 07 '24 18:05 ricardoV94

Can we reuse how numpy implements tensordot?

zaxtax avatar May 07 '24 20:05 zaxtax

Can we reuse how numpy implements tensordot?

We already do that. numpy doesn't have a batched tensordot (except, of course, through einsum), but this PR already has batched tensordot working, just not with arbitrary batch axes. It should only need some extra transposes to get the job done.
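
For illustration, a minimal NumPy-only sketch of that reduction (not the PR's actual code): transpose the batch axes to the front, and the "batch axes on the left" case handles the rest.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 5))  # batch axis of size 3 sits in position 1
y = rng.normal(size=(5, 3, 2))  # batch axis of size 3 sits in position 1

# Move the batch axes to the front...
xb = np.moveaxis(x, 1, 0)  # (3, 4, 5)
yb = np.moveaxis(y, 1, 0)  # (3, 5, 2)

# ...then a per-batch tensordot is enough.
out = np.stack([np.tensordot(a, b, axes=[[1], [0]]) for a, b in zip(xb, yb)])

# Same contraction written directly as an einsum.
np.testing.assert_allclose(out, np.einsum("ibk,kbj->bij", x, y))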

ricardoV94 avatar May 07 '24 20:05 ricardoV94

Huhu convolutions via einsum work :D
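
Not necessarily how the PR does it internally, but as a NumPy-only illustration of the idea, a "valid" 1D correlation can be phrased as an einsum over sliding windows:

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.arange(10.0)            # signal of length 10
w = np.array([1.0, 2.0, 3.0])  # kernel of length 3

windows = sliding_window_view(x, w.shape[0])  # shape (8, 3)
conv = np.einsum("nk,k->n", windows, w)       # "valid" correlation

np.testing.assert_allclose(conv, np.correlate(x, w, mode="valid"))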

ricardoV94 avatar May 08 '24 08:05 ricardoV94

Now Einsum also works with inputs of unknown static shape (unoptimized, of course). We can add a rewrite for when such an Op is found with inputs that now have static shapes (this can be quite relevant in PyMC, when users use freeze_rv_and_dims on a model with mutable coords).
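
To illustrate what "unoptimized" means here (a NumPy-only sketch, not PyTensor code): without shape information the only safe default is a naive left-to-right contraction, which can be written as an explicit einsum path.

import numpy as np

ops = [np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))]

# Contract the first two operands, then contract the result with the next one,
# and so on: ((A @ B) @ C). Constructing this path needs no shape information.
naive_path = ["einsum_path"] + [(0, 1)] * (len(ops) - 1)

out = np.einsum("ij,jk,kl->il", *ops, optimize=naive_path)
print(out.shape)  # (2, 5)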

ricardoV94 avatar Jun 23 '24 18:06 ricardoV94

The ellipsis case is failing due to a bug in opt_einsum: https://github.com/dgasmith/opt_einsum/issues/235

ricardoV94 avatar Jul 05 '24 10:07 ricardoV94

Fix for the ellipsis bug was merged yesterday. I installed from their main and the tests now pass, which is nice! I guess we will have to wait for them to cut a new release, though?

jessegrabowski avatar Jul 07 '24 04:07 jessegrabowski

Fix for the ellipsis bug was merged yesterday. I installed from their main and the tests now pass, which is nice! I guess we will have to wait for them to cut a new release, though?

I think we can create dummy empty arrays with the shapes to avoid the error for the time being.

I don't know how fast they release

ricardoV94 avatar Jul 07 '24 07:07 ricardoV94

Also tests that are failing on CI are passing for me locally, so I don't really know what's going on with that.

jessegrabowski avatar Jul 07 '24 09:07 jessegrabowski

Also tests that are failing on CI are passing for me locally, so I don't really know what's going on with that.

I think one of the Blockwise reshape rewrites I added is buggy and causes an infinite rewrite loop. Gotta investigate.

ricardoV94 avatar Jul 07 '24 11:07 ricardoV94

I fixed the ellipsis case by using dummy empty arrays with the right shapes. Since all the functionality exists in numpy these days, I dropped the dependency on opt_einsum.
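
A minimal sketch of that trick (NumPy-only illustration, not the PR's exact code): path computation only looks at shapes, so uninitialized dummies with the right shapes are enough.

import numpy as np

shapes = [(8, 16), (16, 4), (4, 32)]
dummies = [np.empty(s) for s in shapes]  # values are never read, only shapes

path, info = np.einsum_path("ij,jk,kl->il", *dummies, optimize="optimal")
print(path)  # e.g. ['einsum_path', (0, 1), (0, 1)]
print(info)  # human-readable summary of the chosen contraction order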

ricardoV94 avatar Jul 07 '24 16:07 ricardoV94

Is there any optimization in opt_einsum? I thought that was the value-add over plain numpy.

jessegrabowski avatar Jul 07 '24 16:07 jessegrabowski

Is there any optimization in opt_einsum? I thought that was the value-add over plain numpy.

The optimization functionality of opt_einsum was incorporated into numpy itself some time ago (literally line by line).

ricardoV94 avatar Jul 07 '24 16:07 ricardoV94

Getting close to the finish line. We have to add the optimize flag (or at least decide that we do not want to do it) and do the rewrite for when a graph is rebuilt with static shapes at a later point.

ricardoV94 avatar Jul 10 '24 18:07 ricardoV94

Getting close to the finish line. We have to add the optimize flag (or at least decide that we do not want to do it) and do the rewrite for when a graph is rebuilt with static shapes at a later point.

I like the optimize option, but we agree it only makes sense when you have static shapes, right? It's not required for an initial implementation, right?

zaxtax avatar Jul 10 '24 23:07 zaxtax

Not sure what you mean, @zaxtax. I'm talking about allowing the "optimize" kwarg like numpy has, which defines what kind of optimization to do: optimize: {bool, list, tuple, 'greedy', 'optimal'}. Users can pass their own contraction path as well.

If users pass a contraction path, we don't need to know static shapes. If they set it to greedy/optimal (optimal should be the default), we do need to know them, but we may only find them later. If they don't want to optimize, then obviously we don't need them.
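
For reference, this is what those values look like on numpy's side (illustration only; whether PyTensor's einsum exposes the same kwarg was still being decided at this point):

import numpy as np

a, b, c = np.ones((5, 2)), np.ones((2, 7)), np.ones((7, 3))

np.einsum("ij,jk,kl->il", a, b, c, optimize=False)      # no path optimization
np.einsum("ij,jk,kl->il", a, b, c, optimize="greedy")   # cheap heuristic
np.einsum("ij,jk,kl->il", a, b, c, optimize="optimal")  # exhaustive search
# Explicit user-provided contraction path:
np.einsum("ij,jk,kl->il", a, b, c, optimize=["einsum_path", (0, 1), (0, 1)])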

ricardoV94 avatar Jul 11 '24 10:07 ricardoV94

Oh I misunderstood!

zaxtax avatar Jul 11 '24 12:07 zaxtax

Some unrelated JAX test is failing, probably due to something that changed in a recent release? https://github.com/pymc-devs/pytensor/actions/runs/10161294793/job/28099514375?pr=722#step:6:778

ricardoV94 avatar Jul 30 '24 11:07 ricardoV94

Yes, I am also looking at this now. It's a jax bug that can be recreated easily:

import jax
jax.jit(jax.numpy.tri)(3, 3, 0)

We can ignore it. It looks like their `_canonicalize_axis` function is underflowing (something like `np.uint32(-1)`).

jessegrabowski avatar Jul 30 '24 11:07 jessegrabowski

I think I fixed the tests (not the JAX one) and appeased mypy. @jessegrabowski the docstring extensions are left to you.

ricardoV94 avatar Jul 30 '24 11:07 ricardoV94

I've stopped force-pushing, in case you want to take over.

ricardoV94 avatar Jul 30 '24 11:07 ricardoV94

Opened an issue here: https://github.com/google/jax/issues/22751

I'll hit the docstrings ASAP if that's all that's holding this up

jessegrabowski avatar Jul 30 '24 11:07 jessegrabowski

Opened an issue here: https://github.com/google/jax/issues/22751

Great, let's just mark it as xfail then.

ricardoV94 avatar Jul 30 '24 11:07 ricardoV94

First pass on docstrings. Working on the doctests revealed two things:

  1. Our implementation of _delta does not agree with that of JAX:
from jax._src.lax.lax import _delta as jax_delta
from pytensor.tensor.einsum import _delta as pt_delta
jax_delta(int, (3, 3, 3), (0,1))

Array([[[1, 1, 1],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [1, 1, 1],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [1, 1, 1]]], dtype=int32)


pt_delta((3,3,3), (0,1)).astype(int).eval()
array([[[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]],

       [[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]],

       [[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]]])

We seem to always output the axes=(1,2) case, regardless of the requested axes (see the reference sketch after item 2 below).

  2. Our _general_dot function is discarding shape information somewhere, which doesn't seem right:
import pytensor.tensor as pt
from pytensor.tensor.einsum import _general_dot

A = pt.tensor(shape=(3, 4, 5))
B = pt.tensor(shape=(3, 5, 2))

result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
print(result.type.shape)

(3, None, None)
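
Regarding point 1, here is a hypothetical NumPy reference construction of the generalized delta (ref_delta is a made-up helper, not part of either library); it reproduces the JAX output quoted above:

import numpy as np

def ref_delta(shape, axes):
    # 1 wherever the indices along the selected axes all agree, 0 elsewhere.
    idx = np.indices(shape)
    mask = np.ones(shape, dtype=bool)
    for ax in axes[1:]:
        mask &= idx[ax] == idx[axes[0]]
    return mask.astype(int)

print(ref_delta((3, 3, 3), (0, 1)))  # matches the JAX array above
print(ref_delta((3, 3, 3), (1, 2)))  # matches what pt_delta currently returns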

jessegrabowski avatar Jul 31 '24 12:07 jessegrabowski

Static shape is optional, not a requirement. In our case it probably has to do with the reshape introduced by tensordot and/or with Blockwise, which doesn't do any special shape inference (static or at rewrite time) for core shapes.

That's something we probably want to address for Blockwise in the Numba backend

ricardoV94 avatar Jul 31 '24 13:07 ricardoV94

I understand it's optional, but it also shouldn't be discarded when it is available, no?

jessegrabowski avatar Jul 31 '24 13:07 jessegrabowski

We are not discarding anything on purpose, but an intermediate Op (or Blockwise) doesn't know how to provide a more precise output shape.

There can also be a trade-off where quite a bit of effort may be needed to figure out the static shape, and that may not be worth it at define time. Anyway, the main point is that it shouldn't be a blocker.

We can open an issue for whatever Op is losing the static shape and then assess whether it's worth the cost or not.

ricardoV94 avatar Jul 31 '24 13:07 ricardoV94