@tullio and generated code
Hi, thanks for this great package! I probably have more of a question than an issue with this package. I'm trying to generalize the following function from 3 to N, i.e. I want to contract a vararg number of matrices.
function contractouter(dst::AbstractArray{T,3}, xs::Vararg{AbstractMatrix{T}, 3}) where {T}
a,b,c = xs
@tullio dst[i,j,k] = a[i,z] * b[j,z] * c[k,z]
end
I tried using @generate and build up the expression myself.
@generated function contractouter(dst::AbstractArray{T,N}, xs::Vararg{AbstractMatrix{T}, N}) where {T, N}
xnames = Expr(:tuple)
for i in 1:N
push!(xnames.args, Symbol(:x, i))
end
inds = Expr(:tuple)
for i in 1:N
push!(inds.args, Symbol(:i, i))
end
prod_expr = nothing
for i in 1:N
xi = xnames.args[i]
ii = inds.args[i]
term = :($(Expr(:ref, xi, ii, :z)))
if prod_expr === nothing
prod_expr = term
else
prod_expr = :($prod_expr * $term)
end
end
index_expr = nothing
for i in 1:N
ii = inds.args[i]
term = :($(Expr(:ref, :dst, ii)))
if index_expr === nothing
index_expr = term
else
push!(index_expr.args, ii)
end
end
# @show index_expr, prod_expr
quote
$(Expr(:(=), xnames, :xs))
@tullio $index_expr = $prod_expr
return dst
end
end
But I always end up with this:
julia> M=rand(4,5,6); a=rand(4,10); b=rand(5,10); c=rand(6,10); contractouter(M, a, b, c)
ERROR: The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.
Is @tullio not designed to be called from @generated functions or am I making a mistake somewhere ..
I'm glad if it's useful!
Unfortunately this is a known problem. @tullio defines various functions, and @generated does not allow that. See for example https://github.com/mcabbott/Tullio.jl/issues/11 and https://github.com/mcabbott/Tullio.jl/pull/4#issuecomment-643822313 . This was in part a choice -- making @tullio purely a macro (and not itself generated functions) made is much easier to work with on, when I was writing it.
The best work-around I can suggest is to define your function for various number of arguments. In practice many bad things will happen once you return a 16-array... so perhaps it's not such a limitation to define only up to some fixed maximum number of arguments.
for N in 2:5
arguments = [Symbol(:x, n) for n in 1:N]
indices = [Symbol(:i, n) for n in 1:N]
factors = map(arguments, indices) do x, i
:($x[$i, z])
end
rhs = :(*($(factors...)))
lhs = :(dest[$(indices...)])
ex = :(contractouter($(arguments...)) = @tullio $lhs := $rhs)
@show ex
eval(ex)
end