OMEinsum.jl

Getting type stability with EinCode

Open sethaxen opened this issue 5 years ago • 11 comments

I have a function that multiplies an n-dimensional array by a square matrix along one of its dimensions, so the returned array has the same size as the input array. Because the dimension to multiply along is only known at runtime, I build an EinCode at runtime. However, the result is not type-stable. Is there a good way to give OMEinsum more information so that the compiler can infer the return type? More generally, what's the best way to contract over a single index shared by two arrays when that index is only known at runtime?
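Written out with indices, the operation that f below performs is the following. Here d is dim, n = ndims(V), and M is m×m with m = size(V, d). In f, the label 0 plays the role of the new output index a, and the label dim plays the role of the summed index b.

```latex
Y_{i_1 \dots i_{d-1}\, a\, i_{d+1} \dots i_n}
  = \sum_{b=1}^{m} M_{a b}\, V_{i_1 \dots i_{d-1}\, b\, i_{d+1} \dots i_n}
```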

julia> using OMEinsum

julia> function f(M::AbstractMatrix, V::AbstractArray; dim=1)
           n = ndims(V)
           dimsV = Tuple(Base.OneTo(n))
           dimsY = Base.setindex(dimsV, 0, dim)
           dimsM = (0, dim)
           code = EinCode((dimsV, dimsM), dimsY)
           return einsum(code, (V, M))
       end

julia> M, V = randn(4, 4), randn(10, 4, 2);

julia> f(M, V; dim=2)
10×4×2 Array{Float64,3}:
[:, :, 1] =
  1.19814   -2.83308   -0.82374    6.23831
 -3.856      1.35973    0.168978   1.15039
  3.60948   -2.782     -0.735527   1.44291
 -4.52866    0.361779   0.807384   3.24125
  2.74821    1.30956    1.20418   -5.25221
  4.45576   -0.632032  -1.40112   -5.93926
 -2.1384     0.81895    0.187812  -1.01684
  4.51044   -1.39046   -0.798984  -3.6388
 -0.987397  -0.393374  -1.85841   -0.326891
 -3.02511    2.97092    2.33957   -3.35689

[:, :, 2] =
  1.9988    -2.7311     -2.85731     3.38059
 -5.63312    2.61159     3.5489      7.22906
  1.58536    0.74342    -0.0612845  -5.44578
  0.957018  -0.0174554   0.838485    0.054773
  1.81001   -1.62433    -0.753998    0.165946
  2.69391   -0.0213057  -1.24054    -6.89847
  3.61053   -2.85339    -1.76307    -1.98227
  4.4069    -0.590834    0.724681    0.698118
 -5.60072    1.33233     1.42462     4.45287
 -2.31928   -0.103913    1.75607     7.84296

julia> using Test

julia> @inferred f(M, V; dim=2)
ERROR: return type Array{Float64,3} does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[57]:1

sethaxen, May 04 '20 08:05

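If you comment out the last line of f, so that it returns code instead, its inferred return type is EinCode{_A,_B} where _B where _A. Presumably the index labels are stored as type parameters of EinCode, and here they depend on the runtime value of dim, so the compiler cannot know the concrete type of code. I don't know how much hope there is of the final result being type-stable.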
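A minimal sketch of why this happens, assuming (as the EinCode{_A,_B} signature suggests) that the labels live in type parameters. TypeLabels and ValueLabels are hypothetical stand-ins, not OMEinsum types: the first stores the labels in its type, the second stores them in a field so that only the tuple length is part of the type.

```julia
# Hypothetical stand-ins (NOT OMEinsum types) contrasting two ways of storing labels.
struct TypeLabels{ixs} end        # labels are a type parameter, as EinCode{_A,_B} suggests

struct ValueLabels{N}
    ixs::NTuple{N,Int}            # labels are a runtime value; only N is in the type
end

# dim is a runtime Int, so the compiler cannot know which TypeLabels{...} gets built.
make_type_labels(dim::Int) = TypeLabels{(0, dim)}()

# Whatever dim is, the result is always a ValueLabels{2}.
make_value_labels(dim::Int) = ValueLabels((0, dim))

# Inspect what inference concludes for each constructor helper.
println(Base.return_types(make_type_labels, (Int,)))    # a non-concrete type
println(Base.return_types(make_value_labels, (Int,)))   # the concrete ValueLabels{2}
```

The same trade-off is what the next comment points at: keeping runtime information in values rather than in types is what lets the return type be inferred.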

I think the work here is ultimately done by TensorOperations, which keeps dimensions and strides as values rather than as type parameters. So calling TensorOperations directly, as below, is type-stable:

julia> using TensorOperations

julia> function f4(M, V; dim)
           IA = (-1,0)
           IB = ntuple(d -> d==dim ? 0 : d, ndims(V))
           # IC = (-1, filter(!=(dim), ntuple(+, ndims(V)))...)
           IC = ntuple(d -> d==dim ? -1 : d, ndims(V))
           TensorOperations.tensorcontract(M, IA, V, IB, IC)
       end
f4 (generic function with 1 method)

julia> f4(M, V; dim=2) ≈ f(M, V, dim=2)
true

julia> @code_warntype f4(M, V; dim=2)
...
Body::Array{Float64,3}

mcabbott, Jun 08 '20 18:06
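For comparison, here is a minimal sketch that avoids both EinCode and TensorOperations and uses only Base Julia: move dim to the front with permutedims, do a single matrix product on the flattened array, then move the dimension back. matmul_along is a hypothetical name, not part of either package. As with f4, dim only appears in runtime values (the permutation tuples), never in a type parameter, so the return type depends only on the element type and ndims of V.

```julia
# Hypothetical helper (not an OMEinsum or TensorOperations function):
# Y[.., a, ..] = sum_b M[a, b] * V[.., b, ..], where the summed index is dimension `dim` of V.
function matmul_along(M::AbstractMatrix, V::AbstractArray{T,N}, dim::Int) where {T,N}
    1 <= dim <= N || throw(ArgumentError("dim must be between 1 and $N"))
    # Permutation that moves `dim` to the front and keeps the other dims in order.
    perm = ntuple(d -> d == 1 ? dim : (d <= dim ? d - 1 : d), Val(N))
    # Its inverse, written out explicitly so that it stays an NTuple{N,Int}.
    iperm = ntuple(d -> d == dim ? 1 : (d < dim ? d + 1 : d), Val(N))
    Vp = permutedims(V, perm)                       # size(Vp, 1) == size(V, dim)
    sz = size(Vp)
    Yp = M * reshape(Vp, sz[1], :)                  # one matrix-matrix product
    Y = reshape(Yp, size(M, 1), Base.tail(sz)...)   # restore the remaining dims
    return permutedims(Y, iperm)                    # move the new dim back into place
end

M, V = randn(4, 4), randn(10, 4, 2)
Y = matmul_along(M, V, 2)   # same size as V, since M is square
```

The two permutedims calls copy the data, so this trades some memory traffic for a simple, inferable implementation; whether that matters depends on the array sizes involved.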