MathOptInterface.jl

Investigate using ChainRules for registered functions

Open · odow opened this issue 2 years ago

Discussed with @oxinabox

  • https://github.com/JuliaDiff/ChainRules.jl
  • Used in ReverseDiff https://github.com/JuliaDiff/ReverseDiff.jl/blob/df0067465ca436b05377054ac7a184da691db1a3/src/macros.jl#L298-L316
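For reference, the core entry point here is ChainRules.rrule, which returns the primal value together with a pullback closure. A minimal sketch, assuming ChainRules' built-in rule for sum:

import ChainRules

# `rrule` returns the primal value and a pullback; seeding the pullback
# with 1.0 gives the gradient of a scalar-valued function in one pass.
y, pullback = ChainRules.rrule(sum, [1.0, 2.0, 3.0])
f̄, x̄ = pullback(1.0)
gradient = ChainRules.unthunk(x̄)  # == [1.0, 1.0, 1.0]; rules may return lazy thunks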

odow · Jun 18 '22 01:06

The main blockers are that ChainRules is limited to first-order derivatives, so we'd need nested AD to compute second-order derivatives.
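As an illustrative sketch of the nested-AD idea (using ForwardDiff twice rather than ChainRules itself), a Hessian can be obtained by differentiating the gradient function:

import ForwardDiff
import Statistics

# Second order as AD-over-AD: forward-differentiate the gradient function.
∇var(x) = ForwardDiff.gradient(Statistics.var, x)
H = ForwardDiff.jacobian(∇var, [1.0, 2.0, 3.0])  # 3 x 3 Hessian of var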

I might be missing something, but it also seems set up to compute directional derivatives only, so we'd need multiple calls to compute the full gradient of a vector-valued function.
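To illustrate the second point, a sketch with an invented vector-valued function and a hand-written rrule: each pullback call yields a single vector-Jacobian product, so recovering the full Jacobian of a function with m outputs takes m calls:

import ChainRules
import LinearAlgebra

# Hypothetical h: R^3 -> R^2 with a hand-written reverse rule.
h(x) = [sum(x), LinearAlgebra.dot(x, x)]

function ChainRules.rrule(::typeof(h), x)
    function h_pullback(ȳ)
        # One call = one vector-Jacobian product ȳ' * J.
        x̄ = ȳ[1] .* ones(length(x)) .+ (2 * ȳ[2]) .* x
        return ChainRules.NoTangent(), x̄
    end
    return h(x), h_pullback
end

_, pb = ChainRules.rrule(h, [1.0, 2.0, 3.0])
# Assembling the 2 x 3 Jacobian needs one pullback call per output row.
J = reduce(vcat, permutedims(pb(Float64[j == 1, j == 2])[2]) for j in 1:2)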

odow · Jul 20 '22 01:07

I took a look at this, and I don't see a need to integrate it into JuMP/MOI. It's not hard to write in user space:

using JuMP
import ChainRules
import Ipopt
import Statistics

# Cache the primal value and gradient so that a single `rrule` call
# serves both the function and gradient callbacks below.
mutable struct Cache{F,N}
    f::F
    y::Float64
    dx::Vector{Float64}
    x::NTuple{N,Float64}
    Cache(f, N) = new{typeof(f),N}(f, NaN, fill(NaN, N), ntuple(i -> NaN, N))
end

function update(cache, x...)
    # Recompute only when the evaluation point changes.
    if x != cache.x
        cache.y, pullback = ChainRules.rrule(cache.f, collect(x))
        _, dx = pullback(1.0)
        cache.dx = ChainRules.unthunk(dx)  # rules may return a lazy thunk
        cache.x = x
    end
    return
end

# Objective callback: returns the cached primal value.
function f(cache, x...)
    update(cache, x...)
    return cache.y
end

# Gradient callback: fills `g` in place, as `register` expects.
function ∇f(cache, g, x...)
    update(cache, x...)
    g .= cache.dx
    return
end

model = Model(Ipopt.Optimizer)
c = Cache(Statistics.var, 3)
# Register `f` with an explicit gradient so JuMP skips its own AD.
register(model, :f, 3, (x...) -> f(c, x...), (g, x...) -> ∇f(c, g, x...))
@variable(model, 0 <= x[i=1:3] <= 1 + i / 5)
@NLobjective(model, Min, f(x...))
@constraint(model, sum(x) >= 4)
optimize!(model)
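Not part of the snippet above, but as a quick sanity check (assuming FiniteDifferences.jl is available) you could compare the cached gradient against a finite-difference estimate:

import FiniteDifferences

x0 = (0.3, 0.4, 0.5)
g0 = fill(NaN, 3)
∇f(c, g0, x0...)
fdm = FiniteDifferences.central_fdm(5, 1)
fd = FiniteDifferences.grad(fdm, Statistics.var, collect(x0))[1]
@assert isapprox(g0, fd; atol = 1e-6)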

odow · Aug 22 '22 23:08