
Forward over reverse for variadic function

Open · michel2323 opened this issue 11 months ago · 6 comments

I'm trying to do forward over reverse on this function:

# Rosenbrock
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

To achieve this I'm trying to write a wrapper, since I can't use Active for the forward pass. Here is an example:

# Wrapper to call with Vector, 1st element is output
function f!(inout::Vector{T}) where {T}
    inp = ntuple(length(inout) - 1) do i
        inout[i+1]
    end
    inout[1] = f(inp...)
    nothing
end

inout = [0.0, 1.0, 2.0]
f!(inout)
println("in = $(inout[2]) $(inout[3])")
println("out = $(inout[1])")
dinout = [0.0, 0.0, 0.0]
autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(inout, dinout))

This crashes with

┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Float64}}
└ @ Enzyme ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Float64}}
└ @ Enzyme ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
ERROR: LoadError: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, f, 6, 6)

The other version is the one I sent on Slack. That one doesn't crash, but the adjoints of the inputs are not updated, although the adjoint of the output is zeroed.

using Enzyme

# Rosenbrock
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

function f!(y::Ref{T}, x::Vararg{T,N}) where {T,N}
    t = f(x...)
    y[] = t
    nothing
end

function gradient(x::Vararg{T,N}) where {T,N}
    # tx = map(Active, x)
    # tx = ntuple(i -> Duplicated(x[i], 0.0), N)
    fx = x -> Duplicated(x, 0.0)
    tx = map(fx, x)
    y = Duplicated(Ref{T}(zero(T)), Ref{T}(one(T)))
    autodiff_deferred(ReverseWithPrimal, f!, Const, y, tx...)
    @show tx
    return getfield.(tx, :dval), y.val
end
@show g = gradient(1.0, 2.0)

michel2323 avatar Mar 07 '24 21:03 michel2323

f(inp...) is type unstable, since you take a vector and splat it into an unknown number of elements. Enzyme's error says that we don't yet support that type-unstable tuple operation.

Make inout an ntuple? Alternatively, perhaps mark f as @inline.
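
Roughly, a sketch of that suggestion (untested here; names are just illustrative): keep the mutated output in the array, but give the inputs a statically known length, so the splat inside the wrapper is type stable.

@inline f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

function f_tuple!(out::Vector{T}, x::NTuple{N,T}) where {T,N}
    out[1] = f(x...)   # N is part of the type, so this splat is type stable
    nothing
end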

wsmoses avatar Mar 07 '24 23:03 wsmoses

For the latter, you can't use a Duplicated of a float; it won't do what you expect (which is what's happening in your case). You have to make those Active.
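
For example, a sketch of the reverse pass with Active scalar inputs and a Duplicated output (untested here; same calling convention as the code above, and the expected values assume the Rosenbrock f evaluated at (1.0, 2.0)):

using Enzyme

@inline f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

function f!(y, x::Vararg{T,N}) where {T,N}
    y[1] = f(x...)
    nothing
end

y, ry = zeros(1), ones(1)   # ry seeds the adjoint of the output
res = autodiff(Reverse, f!, Const, Duplicated(y, ry), Active(1.0), Active(2.0))
grad = res[1][2:end]        # adjoints of the Active args; should be (-400.0, 200.0) here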

wsmoses avatar Mar 07 '24 23:03 wsmoses

I can't because I need to apply forward over the reverse pass. Or at least I got an error indicating that.

Do you know whether there is a way to compute second-order derivatives of variadic functions with active number arguments using Enzyme? Or is that unclear at this point?

I read the section on type instability in the documentation, but this time it didn't error with the corresponding warning. As far as I understand, not all type-unstable code fails. And I don't pass in temporary storage.

michel2323 avatar Mar 08 '24 02:03 michel2323

Ok. I think I see. I have to move the tuples out.

Thanks.

michel2323 avatar Mar 08 '24 02:03 michel2323

Applying forward over reverse shouldn't be the issue. For any reverse call (regardless of how it is invoked: directly, in forward over reverse, etc.), Active is required for non-mutable state and Duplicated for mutable state.

So if you got an error, it may help to see what it said?
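
As a tiny illustration of that rule (a sketch, not from this issue; the function name is made up): the scalar argument is Active and its adjoint comes back in the return value, while the mutable array argument is Duplicated and its adjoint accumulates in the shadow.

using Enzyme

mulfirst(a, v) = a * v[1]

dv = [0.0]
res = autodiff(Reverse, mulfirst, Active, Active(3.0), Duplicated([2.0], dv))
res[1][1]   # adjoint of the Active scalar: should be 2.0
dv          # adjoint of the Duplicated vector, accumulated in place: should be [3.0]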

wsmoses avatar Mar 08 '24 02:03 wsmoses

Okay. I finally got that second-order derivative working. Sorry for removing the previous posts.

using Enzyme

# Rosenbrock
@inline function f(x...)
    (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
end

@inline function f!(y, x...)
    y[1] = f(x...)
end

x = (1.0, 2.0)
y = zeros(1)
f!(y,x...)
y[1] = 0.0
ry = ones(1)
g = zeros(2)
rx = ntuple(2) do i
    Active(x[i])
end
function gradient!(g, y, ry, rx...)
    # [1] is the tuple of adjoints: nothing for the Duplicated y, then the Active inputs
    g .= autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(y, ry), rx...)[1][2:end]
    return nothing
end

gradient!(g, y, ry, rx...)


# Forward over reverse (FoR)
y[1] = 0.0
dy = ones(1)
ry[1] = 1.0
dry = zeros(1)
drx = ntuple(2) do i
    Active(one(Float64))
end

tdrx = ntuple(2) do i
    Duplicated(rx[i], drx[i])
end

fill!(g, 0.0)
dg = zeros(2)
autodiff(Forward, gradient!, Const, Duplicated(g, dg), Duplicated(y, dy), Duplicated(ry, dry), tdrx...)
# H * drx
h = dg

michel2323 avatar Mar 08 '24 16:03 michel2323

Here's what I came up with:

julia> import Enzyme

julia> f(x...) = log(sum(exp.(x)))
f (generic function with 1 method)

julia> function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
           g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
           return
       end
∇f! (generic function with 1 method)

julia> function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
           direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
           hess = Enzyme.autodiff(
               Enzyme.Forward,
               (x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
               Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
           )[1]
           for j in 1:N, i in 1:j
               H[j, i] = hess[j][i]
           end
           return
       end
∇²f! (generic function with 1 method)

julia> N = 3
3

julia> x, g, H = rand(N), fill(NaN, N), fill(NaN, N, N);

julia> f(x...)
1.7419820145927152

julia> ∇f!(g, x...)

julia> ∇²f!(H, x...)

julia> g
3-element Vector{Float64}:
 0.24224320758303503
 0.30782611265915005
 0.4499306797578149

julia> H
3×3 Matrix{Float64}:
  0.183561   NaN         NaN
 -0.0745688    0.213069  NaN
 -0.108993    -0.1385      0.247493

See https://github.com/jump-dev/JuMP.jl/pull/3712#issuecomment-1996184943

odow avatar Mar 14 '24 01:03 odow