Forward over reverse for variadic function

11 months ago

I'm trying to apply forward-over-reverse differentiation to this function:

using Enzyme

# Rosenbrock
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

To achieve this, I tried writing a wrapper, since I can't use Active for the forward pass. For example:

# Wrapper to call with Vector, 1st element is output
function f!(inout::Vector{T}) where {T}
    # Collect the inputs (elements 2:end) into a tuple
    inp = ntuple(length(inout) - 1) do i
        inout[i + 1]
    end
    inout[1] = f(inp...)
    return nothing
end

inout = [0.0, 1.0, 2.0]
println("in = $(inout[2]) $(inout[3])")
println("out = $(inout[1])")
dinout = [0.0, 0.0, 0.0]
autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(inout, dinout))

This crashes with

┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Float64}}
└ @ Enzyme ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Float64}}
└ @ Enzyme ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
ERROR: LoadError: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, f, 6, 6)

The other version is the one I sent on Slack. That one doesn't crash, but the adjoints of the inputs are not updated, although the adjoint of the output is zeroed.

using Enzyme

# Rosenbrock
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

function f!(y::Ref{T}, x::Vararg{T,N}) where {T,N}
    t = f(x...)
    y[] = t
end

function gradient(x::Vararg{T,N}) where {T,N}
    # tx = map(Active, x)
    # tx = ntuple(i -> Duplicated(x[i], 0.0), N)
    fx = x -> Duplicated(x, 0.0)
    tx = map(fx, x)
    y = Duplicated(Ref{T}(zero(T)), Ref{T}(one(T)))
    autodiff_deferred(ReverseWithPrimal, f!, Const, y, tx...)
    @show tx
    return getfield.(tx, :dval), y.val
end

@show g = gradient(1.0, 2.0)

michel2323 Mar 07 '24

The call f(inp...) is type unstable, since you take a vector and splat it into an unknown number of elements. Enzyme is reporting that we don't yet support that type-unstable tuple operation.

Could you make inout an NTuple? Alternatively, perhaps mark f as @inline.
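To make the type-instability point concrete, here is a minimal sketch that uses only Base Julia (no Enzyme). The helper names dynamic_tuple and static_tuple are hypothetical and are not part of the original post. It compares the inferred tuple type when the length comes from the vector at run time with the one obtained when the length is passed as a Val.

```julia
# Rosenbrock, variadic as in the original post
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

# Hypothetical helper: the tuple length is read from the vector at run time,
# so the compiler cannot know how many elements the tuple has.
dynamic_tuple(v::Vector{Float64}) = ntuple(i -> v[i + 1], length(v) - 1)

# Hypothetical helper: the tuple length is a compile-time constant via Val.
static_tuple(v::Vector{Float64}, ::Val{N}) where {N} = ntuple(i -> v[i + 1], Val(N))

inout = [0.0, 1.0, 2.0]

# Ask the compiler what it infers for each helper.
dyn_type = only(Base.return_types(dynamic_tuple, (Vector{Float64},)))
sta_type = only(Base.return_types(static_tuple, (Vector{Float64}, Val{2})))

println("dynamic length: ", dyn_type, " concrete? ", isconcretetype(dyn_type))
println("static length:  ", sta_type, " concrete? ", isconcretetype(sta_type))

# Splatting the statically sized tuple calls f with a known number of arguments.
fval = f(static_tuple(inout, Val(2))...)
println(fval)  # 100.0
```

The dynamic variant should infer a non-concrete tuple type, which is the kind of type the Enzyme warning above complains about, while the Val variant should infer Tuple{Float64, Float64}. Whether that alone is enough for Enzyme to differentiate the wrapper is not something this sketch checks.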

wsmoses Mar 07 '24

For the latter, you can't use a Duplicated of a float; it won't do what you expect (which is what's happening in your case). You have to make those Active.

wsmoses Mar 07 '24

I can't because I need to apply forward over the reverse pass. Or at least I got an error indicating that.

Do you know whether there is a way to compute second-order derivatives of variadic functions with Active numeric arguments using Enzyme? Or is that unclear?

I read the type unstable part in the documentation, but it did not error this time with the proper warning. As far as I understand, not every type unstable code fails. And I don't pass in temporary storage.

michel2323 Mar 08 '24

Ok. I think I see. I have to move the tuples out.


michel2323 Mar 08 '24

Applying forward over reverse shouldn't be the issue. For any reverse call (independent of how it is called: directly, in forward over reverse, etc.), Active is required for non-mutable state and Duplicated for mutable state.

So if you got an error, it may help to see what it said?
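As an illustration of that rule, here is a minimal sketch, assuming a reasonably recent Enzyme.jl in which autodiff(Reverse, f, ReturnActivity, args...) returns the derivatives of Active arguments as the first element of its result. The helper f_vec! is hypothetical and is not from the thread.

```julia
using Enzyme

# Rosenbrock, variadic as in the original post
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

# Immutable scalars: annotate them as Active.
# The gradient comes back in the return value, one entry per argument.
grads = autodiff(Reverse, f, Active, Active(1.0), Active(2.0))[1]
println(grads)

# Mutable arrays: annotate them as Duplicated.
# The gradient is accumulated in place into the shadow array.
function f_vec!(out, x)  # hypothetical helper for illustration
    out[1] = f(x[1], x[2])
    return nothing
end

x, dx = [1.0, 2.0], zeros(2)
out, dout = zeros(1), ones(1)  # dout seeds the adjoint of the output
autodiff(Reverse, f_vec!, Const, Duplicated(out, dout), Duplicated(x, dx))
println(dx)
```

At (1.0, 2.0) the Rosenbrock gradient is (-400, 200) by hand, so both printed results should match that if the assumptions about the API hold.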

wsmoses Mar 08 '24

Okay. I finally got the second-order derivative working. Sorry for removing my previous posts.

using Enzyme

# Rosenbrock
@inline function f(x...)
    (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
end

@inline function f!(y, x...)
    y[1] = f(x...)
end

x = (1.0, 2.0)
y = zeros(1)
y[1] = 0.0
ry = ones(1)
g = zeros(2)
rx = ntuple(2) do i
    Active(x[i])
end

function gradient!(g, y, ry, rx...)
    g .= autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(y, ry), rx...)[1][2:end]
    return nothing
end

gradient!(g, y,ry,rx...)

# FoR
y[1] = 0.0
dy = ones(1)
ry[1] = 1.0
dry = zeros(1)
drx = ntuple(2) do i
    Active(one(Float64))
end

tdrx = ntuple(2) do i
    Duplicated(rx[i], drx[i])
end

fill!(g, 0.0)
dg = zeros(2)
autodiff(Forward, gradient!, Const, Duplicated(g,dg), Duplicated(y,dy), Duplicated(ry, dry), tdrx...)
# dg now holds the Hessian-vector product H * drx
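A result like this can be sanity-checked without Enzyme. The following is a minimal sketch that differentiates a hand-derived Rosenbrock gradient with central differences. The names grad and hvp_fd are hypothetical, and the direction of all ones is an assumption chosen for illustration.

```julia
# Hand-derived gradient of f(x1, x2) = (1 - x1)^2 + 100 * (x2 - x1^2)^2
grad(x1, x2) = [-2 * (1 - x1) - 400 * x1 * (x2 - x1^2), 200 * (x2 - x1^2)]

# A central difference of the gradient along direction v approximates H * v.
function hvp_fd(x, v; h = 1e-6)
    gp = grad((x .+ h .* v)...)
    gm = grad((x .- h .* v)...)
    return (gp .- gm) ./ (2h)
end

x = [1.0, 2.0]
v = [1.0, 1.0]  # assumed direction
hv = hvp_fd(x, v)
println(hv)  # approximately [2.0, -200.0]
```

By hand, the Hessian at (1, 2) is [402 -400; -400 200], so H * [1, 1] = [2, -200], which is the value a forward-over-reverse product in that direction should reproduce.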

michel2323 Mar 08 '24

Here's what I came up with:

julia> import Enzyme

julia> f(x...) = log(sum(exp.(x)))
f (generic function with 1 method)

julia> function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
           g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
           return
       end
∇f! (generic function with 1 method)

julia> function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
           direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
           hess = Enzyme.autodiff(
               Enzyme.Forward,
               (x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
               Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
           )[1]
           for j in 1:N, i in 1:j
               H[j, i] = hess[j][i]
           end
           return
       end
∇²f! (generic function with 1 method)

julia> N = 3

julia> x, g, H = rand(N), fill(NaN, N), fill(NaN, N, N);

julia> f(x...)

julia> ∇f!(g, x...)

julia> ∇²f!(H, x...)

julia> g
3-element Vector{Float64}:

julia> H
3×3 Matrix{Float64}:
  0.183561   NaN         NaN
 -0.0745688    0.213069  NaN
 -0.108993    -0.1385      0.247493
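For this particular f, the Hessian has a well-known closed form that can serve as an independent reference: with p = softmax(x), the gradient of log(sum(exp.(x))) is p and the Hessian is Diagonal(p) - p * p'. The sketch below uses only the LinearAlgebra standard library. The names softmax and logsumexp_hessian are hypothetical, and x is drawn at random, so the numbers will differ from the session above.

```julia
using LinearAlgebra

# Softmax of x; this is also the gradient of log(sum(exp.(x))).
softmax(x) = exp.(x) ./ sum(exp.(x))

# Closed-form Hessian of log-sum-exp: Diagonal(p) - p * p'.
function logsumexp_hessian(x)
    p = softmax(x)
    return Diagonal(p) - p * p'
end

x = rand(3)
H_ref = logsumexp_hessian(x)

# Only the lower triangle is filled in the session above, so compare against this.
println(LowerTriangular(H_ref))

# Each row of the Hessian sums to zero because the entries of p sum to one.
println(sum(H_ref; dims = 2))
```

The matrix printed in the thread is consistent with this structure: its diagonal is positive, its off-diagonal entries are negative, and the last row sums to roughly zero.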

odow Mar 14 '24