Tullio.jl icon indicating copy to clipboard operation
Tullio.jl copied to clipboard

@tullio and generated code

Open ceesb opened this issue 4 months ago • 1 comment

Hi, thanks for this great package! I probably have more of a question than an issue with this package. I'm trying to generalize the following function from 3 to N, i.e. I want to contract a vararg number of matrices.

"""
    contractouter(dst, a, b, c)

Write into the 3-dimensional array `dst` the contraction of three matrices over
their shared second index `z`, i.e. `dst[i,j,k] = a[i,z] * b[j,z] * c[k,z]`
evaluated by `@tullio`. The value of the `@tullio` expression is returned.
"""
function contractouter(dst::AbstractArray{T,3}, xs::Vararg{AbstractMatrix{T}, 3}) where {T}
    # Unpack the three matrices so that `@tullio` can refer to each one by name.
    m1, m2, m3 = xs
    return @tullio dst[i, j, k] = m1[i, z] * m2[j, z] * m3[k, z]
end

I tried using @generated and building up the expression myself.

"""
    contractouter(dst, xs...)

Attempted N-ary version: for `N` matrices `xs` and an `N`-dimensional `dst`,
generate the body

    (x1, ..., xN) = xs
    @tullio dst[i1, ..., iN] = x1[i1, z] * ... * xN[iN, z]
    return dst

NOTE: as posted, calling this raises the "function body AST ... is not pure"
error quoted below the snippet.
"""
@generated function contractouter(dst::AbstractArray{T,N}, xs::Vararg{AbstractMatrix{T}, N}) where {T, N}
    # Everything before the final `quote` runs at code-generation time, where `N`
    # is known; only the quoted expression becomes the method body.
    matnames = [Symbol(:x, n) for n in 1:N]
    idxnames = [Symbol(:i, n) for n in 1:N]

    # Left-nested product ((x1[i1, z] * x2[i2, z]) * ...); stays `nothing` if N == 0.
    rhs = nothing
    for (mat, idx) in zip(matnames, idxnames)
        factor = Expr(:ref, mat, idx, :z)
        rhs = rhs === nothing ? factor : :($rhs * $factor)
    end

    # Indexed destination dst[i1, ..., iN]; stays `nothing` if N == 0.
    lhs = N == 0 ? nothing : Expr(:ref, :dst, idxnames...)

    # Destructuring assignment (x1, ..., xN) = xs.
    unpack = Expr(:(=), Expr(:tuple, matnames...), :xs)

    # @show lhs, rhs

    quote
        $unpack
        @tullio $lhs = $rhs
        return dst
    end
end

But I always end up with this:

julia> M=rand(4,5,6); a=rand(4,10); b=rand(5,10); c=rand(6,10); contractouter(M, a, b, c)
ERROR: The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.

Is @tullio not designed to be called from @generated functions, or am I making a mistake somewhere?

ceesb avatar Jul 29 '25 21:07 ceesb

I'm glad if it's useful!

Unfortunately this is a known problem. @tullio defines various functions, and @generated does not allow that. See for example https://github.com/mcabbott/Tullio.jl/issues/11 and https://github.com/mcabbott/Tullio.jl/pull/4#issuecomment-643822313 . This was in part a choice -- making @tullio purely a macro (and not itself generated functions) made it much easier to work on when I was writing it.

The best work-around I can suggest is to define your function for various numbers of arguments. In practice many bad things will happen once you return a 16-dimensional array... so perhaps it's not such a limitation to define only up to some fixed maximum number of arguments.

# Define one `contractouter` method for each argument count from 2 to 5 by
# building the definition as an expression and evaluating it at top level.
for N in 2:5
    arguments = map(n -> Symbol(:x, n), 1:N)
    indices = map(n -> Symbol(:i, n), 1:N)
    # One factor `xn[in, z]` per argument; every factor shares the index `z`.
    factors = [Expr(:ref, x, i, :z) for (x, i) in zip(arguments, indices)]
    rhs = Expr(:call, :*, factors...)
    lhs = Expr(:ref, :dest, indices...)
    # `:=` makes `@tullio` create the array `dest`, which the method returns.
    ex = :(contractouter($(arguments...)) = @tullio $lhs := $rhs)
    @show ex
    eval(ex)
end

mcabbott avatar Jul 30 '25 01:07 mcabbott