
`Vararg`s get clobbered?

Status: Open. darsnack opened this issue 3 years ago · 6 comments

I have the following code:

f(x, xs...) = max.(x, xs...)

where `max` is treated as a primitive. When I trace `f(rand(2, 2), rand(2, 2))` with Ghost, I get:

Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Matrix{Float64}
  inp %3::Matrix{Float64}
  %4 = tuple(max, %2)::Tuple{typeof(max), Matrix{Float64}}
  %5 = _apply_iterate(iterate, broadcasted, %4, %3)::Broadcasted{}
  %6 = materialize(%5)::Matrix{Float64}

Compare this to @code_lowered:

CodeInfo(
1 ─ %1 = Core.tuple(Main.max, x)
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, xs)
│   %3 = Base.materialize(%2)
└──      return %3
)

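The call in `@code_lowered` makes sense: `%1` and `xs` are both tuples, so both splat correctly. In the tape, however, `%3` will not splat correctly, because it refers to the input matrix itself rather than to the intermediate Vararg tuple `xs` that wraps it. Splatting a matrix iterates over its elements, so replaying the tape would pass the individual entries of the second matrix to `broadcasted` instead of the matrix itself.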
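To make the difference concrete, here is a small sketch in plain Julia (no Ghost needed) that performs both splats by hand. The matrices are arbitrary example values, and `xs` stands in for the Vararg tuple that `f` receives. The splatting syntax `Base.broadcasted(max, x, xs...)` lowers to the same kind of `Core._apply_iterate(Base.iterate, Base.broadcasted, ...)` call that appears in both listings above.

```julia
# Arbitrary example inputs
x = [1.0 5.0; 3.0 2.0]
y = [4.0 0.0; 1.0 6.0]

# Inside f(x, xs...), the Vararg `xs` is a 1-tuple holding y
xs = (y,)

# What the lowered code does: splat the tuple `xs`,
# i.e. broadcasted(max, x, y). Result equals max.(x, y) == [4.0 5.0; 3.0 6.0]
good = Base.materialize(Base.broadcasted(max, x, xs...))

# What the tape does: splat the matrix `y` itself. Iterating a matrix
# yields its elements in column-major order, so this becomes
# broadcasted(max, x, 4.0, 1.0, 0.0, 6.0), and every entry ends up 6.0
bad = Base.materialize(Base.broadcasted(max, x, y...))
```

Both calls run without error and return a 2×2 matrix, which is why the wrong tape can go unnoticed until its results are compared against `f`.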

darsnack, Dec 05 '21 17:12

An even simpler MWE:

julia> f(x, xs...) = getindex.(x, xs...)
f (generic function with 1 method)

julia> Ghost.trace(f, X, 1)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Matrix{Float64}
  inp %3::Int64
  %4 = tuple(getindex, %2)::Tuple{typeof(getindex), Matrix{Float64}}
  %5 = _apply_iterate(iterate, broadcasted, %4, %3)::Broadcasted{}
  %6 = materialize(%5)::Matrix{Float64}

julia> @code_lowered f(X, 1)
CodeInfo(
1 ─ %1 = Core.tuple(Main.getindex, x)
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, xs)
│   %3 = Base.materialize(%2)
└──      return %3
)

darsnack, Dec 05 '21 17:12

After looking into how to fix this, it seems the problem arises when the top-level function being traced has a Vararg signature. Since `Ghost.trace(f, args...)` already uses splatting, would we need to correctly de-sugar `args` into the signature that `f` expects (here, `x` plus a tuple `xs`) in order for the tracer to "see" the splatting inside `f`?
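One way to picture that de-sugaring: before tracing, regroup the flat `args` so that the trailing arguments arrive as a single tuple, matching the `xs` slot in the lowered body. The sketch below does only that regrouping. `desugar_args` is a hypothetical name, not part of Ghost; it relies on the `isva` and `nargs` fields of Julia's internal `Method` object, and it does not touch Ghost or IRTools, so it is an illustration of the idea rather than a fix.

```julia
# Hypothetical helper (not part of Ghost's API): regroup a flat argument
# list into the slots the lowered method body expects, packing trailing
# arguments into a tuple when the dispatched method is Vararg.
function desugar_args(f, args...)
    m = which(f, Base.typesof(args...))   # method that would be dispatched to
    m.isva || return args                 # no Vararg: nothing to regroup
    # m.nargs counts the function slot plus every formal argument slot,
    # the Vararg slot included, so the fixed positional count is nargs - 2
    nfixed = Int(m.nargs) - 2
    return (args[1:nfixed]..., args[nfixed+1:end])
end

f(x, xs...) = max.(x, xs...)   # Vararg method from the first post
g(x, y) = max.(x, y)           # fixed-arity method for comparison

X, Y = rand(2, 2), rand(2, 2)
packed = desugar_args(f, X, Y)   # (X, (Y,)): xs is now an explicit tuple
plain  = desugar_args(g, X, Y)   # (X, Y): returned unchanged
```

With the arguments in this shape, a tracer entering `f` could bind its second input to the tuple `xs`, so a recorded `_apply_iterate(..., %3)` would splat the tuple rather than the matrix.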

darsnack, Dec 05 '21 17:12

Busy right now, but this seems to be the same issue as dfdx/Yota.jl#84.

dfdx, Dec 05 '21 18:12

Yes, it's the same issue as the one I linked. Unfortunately, I don't have an immediate fix for it: IRTools adds some magic that I don't know how to mitigate. Recently I looked at Mixtape.jl and CodeInfoTools.jl as more future-proof alternatives to IRTools.jl, but the refactoring would take quite a lot of time.

Usually, wrapping the top-level function in another function without varargs solves the issue. Would that be sufficient for your current use case?
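A minimal sketch of that workaround for the example from the first post, assuming the number of trailing arguments is known in advance. `f_fixed` is a hypothetical name; the Ghost call is left as a comment and follows the `Ghost.trace(f, args...)` form used earlier in this thread.

```julia
f(x, xs...) = max.(x, xs...)

# Fixed-arity wrapper for the two-argument case: its own signature has
# no Vararg, so the function handed to the tracer takes plain arguments
f_fixed(x, y) = f(x, y)

X, Y = rand(2, 2), rand(2, 2)
@assert f_fixed(X, Y) == f(X, Y)   # same result as calling f directly

# With Ghost loaded, trace the wrapper instead of f, e.g.
# val, tape = Ghost.trace(f_fixed, X, Y)
```

The catch is that one wrapper has to be written per arity, so this works for a known call site but not for arbitrary user-defined Vararg functions.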

dfdx, Dec 05 '21 22:12

In my case, I managed to avoid the Vararg, but my package that uses Ghost is supposed to support any user-defined function, so eventual support for Vararg would be nice.

darsnack, Dec 10 '21 14:12

Agreed. My current plan is to look at various alternatives to IRTools, which is the bottleneck in this issue and is to be deprecated anyway, and see which is easier: moving to a new tech stack, or diving into the internals of IRTools and fixing it in the current version. Unfortunately, both options are quite complicated, so there is no estimated time for a resolution yet :disappointed:

dfdx, Dec 11 '21 22:12