Getting type stability with EinCode
I have a function that multiplies an n-dimensional array by a square matrix along one of its dimensions, so the returned array has the same size as the input array. Because the dimension to multiply along is only known at runtime, I use EinCode. However, the result is not type-stable. Is there a good way to give OMEinsum more information so the compiler can infer the return type? Or, more generally, what is the best way to contract over a single index shared by two arrays when that index is only known at runtime?
julia> using OMEinsum
julia> function f(M::AbstractMatrix, V::AbstractArray; dim=1)
           n = ndims(V)
           dimsV = Tuple(Base.OneTo(n))          # label V's dimensions 1, 2, ..., n
           dimsY = Base.setindex(dimsV, 0, dim)  # output: label 0 replaces dim
           dimsM = (0, dim)                      # M's second index is contracted with V's dim
           code = EinCode((dimsV, dimsM), dimsY)
           return einsum(code, (V, M))
       end
julia> M, V = randn(4, 4), randn(10, 4, 2);
julia> f(M, V; dim=2)
10×4×2 Array{Float64,3}:
[:, :, 1] =
1.19814 -2.83308 -0.82374 6.23831
-3.856 1.35973 0.168978 1.15039
3.60948 -2.782 -0.735527 1.44291
-4.52866 0.361779 0.807384 3.24125
2.74821 1.30956 1.20418 -5.25221
4.45576 -0.632032 -1.40112 -5.93926
-2.1384 0.81895 0.187812 -1.01684
4.51044 -1.39046 -0.798984 -3.6388
-0.987397 -0.393374 -1.85841 -0.326891
-3.02511 2.97092 2.33957 -3.35689
[:, :, 2] =
1.9988 -2.7311 -2.85731 3.38059
-5.63312 2.61159 3.5489 7.22906
1.58536 0.74342 -0.0612845 -5.44578
0.957018 -0.0174554 0.838485 0.054773
1.81001 -1.62433 -0.753998 0.165946
2.69391 -0.0213057 -1.24054 -6.89847
3.61053 -2.85339 -1.76307 -1.98227
4.4069 -0.590834 0.724681 0.698118
-5.60072 1.33233 1.42462 4.45287
-2.31928 -0.103913 1.75607 7.84296
julia> using Test
julia> @inferred f(M, V; dim=2)
ERROR: return type Array{Float64,3} does not match inferred return type Any
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] top-level scope at REPL[57]:1
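If you comment out the last line of f (so that f returns code), its inferred return type is EinCode{_A,_B} where _B where _A. The index labels are type parameters of EinCode, and because they depend on the runtime value of dim, the compiler cannot know them. So I don't know how much hope there is of making the final result type-stable.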
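To make the mechanism concrete, here is a minimal stand-in in plain Julia. The type LabelCode and the functions make_code, evaluate, unstable and stable are hypothetical and are not OMEinsum's API; the only assumption carried over is that, as with EinCode, the labels live in the type parameters. Building them from a runtime dim leaves inference with a non-concrete type. A common workaround is a type assertion on the result: the call is still dispatched dynamically, but callers see a concrete return type. Applied to f, that would presumably be something like einsum(code, (V, M))::Array{promote_type(eltype(V), eltype(M)), ndims(V)}, assuming plain Array inputs give a plain Array result as in the output above (untested sketch).

```julia
using Test

# Hypothetical stand-in for an EinCode-like type (not OMEinsum's API):
# the index labels are stored as type parameters.
struct LabelCode{IX,IY} end

# Labels built from a runtime `dim` become type parameters whose values are
# only known at runtime, so inference sees a non-concrete LabelCode.
make_code(dim::Int) = LabelCode{(1, dim), (0, dim)}()

# The output's number of dimensions is read from the type parameter IY,
# so the concrete array type is only known once IY is known.
evaluate(::LabelCode{IX,IY}, x::Vector{Float64}) where {IX,IY} =
    fill(sum(x), ntuple(_ -> 1, length(IY)))

# Not inferable: the return type depends on the runtime labels.
unstable(x, dim) = evaluate(make_code(dim), x)

# Workaround: assert the type the result is known to have. Dispatch inside is
# still dynamic, but the return type seen by callers is concrete.
stable(x, dim) = evaluate(make_code(dim), x)::Matrix{Float64}

x = [1.0, 2.0]
Base.return_types(unstable, (Vector{Float64}, Int))  # a non-concrete type
stable(x, 2)                                         # 1×1 Matrix{Float64} holding 3.0
@inferred stable(x, 2)                               # does not throw
```

The assertion only helps if you really know the result type; if it is wrong, you get a TypeError at runtime instead of silently wrong inference.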
I think the work here is ultimately done by TensorOperations, which keeps dimensions and strides as values rather than types. So this version is type-stable:
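julia> using TensorOperations
julia> function f4(M, V; dim)
           IA = (-1, 0)                                   # labels for M: -1 is free, 0 is contracted
           IB = ntuple(d -> d == dim ? 0 : d, ndims(V))   # labels for V: dimension dim is contracted
           # alternative that puts the new index first:
           # IC = (-1, filter(!=(dim), ntuple(+, ndims(V)))...)
           IC = ntuple(d -> d == dim ? -1 : d, ndims(V))  # output: M's free index takes dim's place
           TensorOperations.tensorcontract(M, IA, V, IB, IC)
       end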
f4 (generic function with 1 method)
julia> f4(M, V; dim=2) ≈ f(M, V, dim=2)
true
julia> @code_warntype f4(M, V; dim=2)
...
Body::Array{Float64,3}
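If you would rather avoid both packages, the same contraction can be sketched with only Base and LinearAlgebra: permute dim to the front, flatten the remaining dimensions into columns, do one matrix product, then undo the reshape and the permutation. As in the TensorOperations version, dim is an ordinary runtime value and only ndims(V) (a type parameter) is used at compile time, so the return type should be inferable. The name contract_dim is hypothetical, and this version allocates two permuted copies, so it may well be slower than f4 for large arrays.

```julia
using LinearAlgebra  # provides matrix `*` and array `≈`
using Test           # for @inferred

# Multiply matrix M into array V along dimension dim:
#   Y[.., i, ..] = sum_k M[i, k] * V[.., k, ..]
# dim is a runtime value; only N (a type parameter) is needed at compile time.
function contract_dim(M::AbstractMatrix, V::AbstractArray{T,N}; dim::Int=1) where {T,N}
    1 <= dim <= N || throw(ArgumentError("dim must be between 1 and $N"))
    size(M, 2) == size(V, dim) ||
        throw(DimensionMismatch("size(M, 2) must equal size(V, dim)"))
    # Permutation that moves dim to the front, and its inverse.
    perm  = ntuple(i -> i == 1 ? dim : (i <= dim ? i - 1 : i), Val(N))
    iperm = ntuple(i -> i < dim ? i + 1 : (i == dim ? 1 : i), Val(N))
    Vp = permutedims(V, perm)          # size(V, dim) first, other dims after
    Vm = reshape(Vp, size(Vp, 1), :)   # flatten the other dims into columns
    Ym = M * Vm                        # one matrix product does the contraction
    Yp = reshape(Ym, (size(M, 1), Base.tail(size(Vp))...))
    return permutedims(Yp, iperm)      # move the new dimension back to position dim
end

M, V = randn(4, 4), randn(10, 4, 2)
Y = @inferred contract_dim(M, V; dim=2)  # @inferred throws if inference fails
size(Y)                                  # (10, 4, 2)
```

With the same M and V, contract_dim(M, V; dim=2) should agree with f(M, V; dim=2) and f4(M, V; dim=2) up to rounding, since all three compute the same sum over the shared index.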