MathOptInterface.jl icon indicating copy to clipboard operation
MathOptInterface.jl copied to clipboard

Symbolic AD of ScalarNonlinearFunction

Open odow opened this issue 1 year ago • 2 comments

This has come up quite a few times, so I think we need this.

I don't know what the right API is. Perhaps:

MOI.Nonlinear.gradient(f) :: Dict{MOI.VariableIndex,MOI.ScalarNonlinearFunction}

The use case for this would be:

  • to replace Symbolics.jl in MathOptSymbolicAD
  • Simple cases like https://discourse.julialang.org/t/auto-gradient-with-symbolics-for-jump-constraints/117953

It's okay for this to have all the usual issues with symbolic AD.

odow avatar Aug 12 '24 22:08 odow

I started hacking something for this:

module SymbolicAD

import MacroTools
import MathOptInterface as MOI

# The derivative of a numeric constant is zero. `false` is used as a strong
# zero so that `simplify` can drop the term entirely.
derivative(::Real, ::MOI.VariableIndex) = false

# The derivative of a variable is 1 if it is the differentiation variable and
# 0 otherwise. `f == x` already yields a `Bool` (a strong 0/1), so the
# original `ifelse(f == x, true, false)` was redundant.
derivative(f::MOI.VariableIndex, x::MOI.VariableIndex) = f == x

function derivative(
    f::MOI.ScalarAffineFunction{T},
    x::MOI.VariableIndex,
) where {T}
    # d/dx (Σᵢ cᵢ xᵢ + b) is the sum of the coefficients attached to `x`.
    # A variable may appear in several terms when `f` is not canonical, so
    # all matching coefficients are accumulated.
    return sum(
        (term.coefficient for term in f.terms if term.variable == x);
        init = zero(T),
    )
end

function derivative(
    f::MOI.ScalarQuadraticFunction{T},
    x::MOI.VariableIndex,
) where {T}
    # The affine part contributes its coefficients on `x` as a constant.
    constant = zero(T)
    for term in f.affine_terms
        if term.variable == x
            constant += term.coefficient
        end
    end
    # MOI convention: a ScalarQuadraticTerm with coefficient c represents
    # c * x_i * x_j off the diagonal, and (c / 2) * x_i^2 on the diagonal.
    aff_terms = MOI.ScalarAffineTerm{T}[]
    for q_term in f.quadratic_terms
        if q_term.variable_1 == q_term.variable_2 == x
            # Diagonal term: d/dx (c/2) x^2 = c * x.
            push!(aff_terms, MOI.ScalarAffineTerm(q_term.coefficient, x))
        elseif q_term.variable_1 == x
            # Off-diagonal term c * x * v2: derivative is c * v2.
            push!(
                aff_terms,
                MOI.ScalarAffineTerm(q_term.coefficient, q_term.variable_2),
            )
        elseif q_term.variable_2 == x
            # Off-diagonal term c * v1 * x: derivative is c * v1.
            push!(
                aff_terms,
                MOI.ScalarAffineTerm(q_term.coefficient, q_term.variable_1),
            )
        end
    end
    return MOI.ScalarAffineFunction(aff_terms, constant)
end

"""
    derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)

Return a symbolic expression for `d f / d x`.

Zeros and ones are returned as `Bool` (`false`/`true`) so that `simplify` can
treat them as strong constants. Operators with no rule implemented (for
example two-argument `atan`, or `u ^ p` where the exponent `p` depends on `x`)
throw `MOI.UnsupportedNonlinearOperator`.
"""
function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
    if length(f.args) == 1
        # Unary operators with special-cased rules.
        u = only(f.args)
        if f.head == :+
            return derivative(u, x)
        elseif f.head == :-
            return MOI.ScalarNonlinearFunction(:-, Any[derivative(u, x)])
        elseif f.head == :abs
            # d/du abs(u) = sign(u); the `u >= 0` test chooses +1 as the
            # subgradient at u == 0.
            scale = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[MOI.ScalarNonlinearFunction(:>=, Any[u, 0]), 1, -1],
            )
            return MOI.ScalarNonlinearFunction(:*, Any[scale, derivative(u, x)])
        elseif f.head == :sign
            # sign is piecewise constant, so its derivative is 0 a.e.
            return false
        end
        # Generic univariate rules from MOI's symbolic derivative table.
        for (key, df, _) in MOI.Nonlinear.SYMBOLIC_UNIVARIATE_EXPRESSIONS
            if key == f.head
                # The chain rule: d(f(g(x))) / dx = f'(g(x)) * g'(x)
                u = only(f.args)
                # `df` is a Julia `Expr` in the single symbol `:x`: substitute
                # `u` for `:x` and convert call expressions into
                # ScalarNonlinearFunctions bottom-up.
                df_du = MacroTools.postwalk(df) do node
                    if node === :x
                        return u
                    elseif Meta.isexpr(node, :call)
                        op, args = node.args[1], node.args[2:end]
                        return MOI.ScalarNonlinearFunction(op, args)
                    end
                    return node
                end
                du_dx = derivative(u, x)
                return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
            end
        end
    end
    if f.head == :+
        # d/dx(+(args...)) = +(d/dx args)
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:+, args)
    elseif f.head == :-
        # d/dx(-(args...)) = -(d/dx args)
        # Note that - is not unary here because that would be caught above.
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:-, args)
    elseif f.head == :*
        # Product rule: d/dx(*(args...)) = sum(d{i}/dx * args\{i})
        sum_terms = Any[]
        for i in 1:length(f.args)
            # Copy the factor list and replace the i-th factor with its
            # derivative.
            g = MOI.ScalarNonlinearFunction(:*, copy(f.args))
            g.args[i] = derivative(f.args[i], x)
            push!(sum_terms, g)
        end
        return MOI.ScalarNonlinearFunction(:+, sum_terms)
    elseif f.head == :^
        @assert length(f.args) == 2
        u, p = f.args
        du_dx = derivative(u, x)
        dp_dx = derivative(p, x)
        if _iszero(dp_dx)
            # p is constant and does not depend on x
            # Power rule: d/dx u^p = p * u^(p-1) * du/dx.
            df_du = MOI.ScalarNonlinearFunction(
                :*,
                Any[p, MOI.ScalarNonlinearFunction(:^, Any[u, p-1])],
            )
            # NOTE(review): `du_dx` was already computed above; this second
            # call is redundant but harmless.
            du_dx = derivative(u, x)
            return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
        else
            # u(x)^p(x)
            # Not implemented: falls through to the unsupported-operator
            # error at the bottom of this function.
        end
    elseif f.head == :/
        # Quotient rule: d/dx(u / v) = ((du/dx)*v - u*(dv/dx)) / v^2
        @assert length(f.args) == 2
        u, v = f.args
        du_dx, dv_dx = derivative(u, x), derivative(v, x)
        return MOI.ScalarNonlinearFunction(
            :/,
            Any[
                MOI.ScalarNonlinearFunction(
                    :-,
                    Any[
                        MOI.ScalarNonlinearFunction(:*, Any[du_dx, v]),
                        MOI.ScalarNonlinearFunction(:*, Any[u, dv_dx]),
                    ],
                ),
                MOI.ScalarNonlinearFunction(:^, Any[v, 2]),
            ],
        )
    elseif f.head == :ifelse
        @assert length(f.args) == 3
        # Pick the derivative of the active branch
        # (the condition is treated as locally constant).
        return MOI.ScalarNonlinearFunction(
            :ifelse,
            Any[f.args[1], derivative(f.args[2], x), derivative(f.args[3], x)],
        )
    elseif f.head == :atan
        # TODO
        # Two-argument atan(y, x) is not yet supported: falls through to the
        # unsupported-operator error below.
    elseif f.head == :min
        # Build a right-to-left chain of ifelse: the derivative of min is the
        # derivative of whichever argument attains the minimum.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(<=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head == :max
        # Same construction as :min, with the comparison reversed.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(>=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head in (:(>=), :(<=), :(<), :(>), :(==))
        # Comparisons are piecewise constant: derivative is 0 a.e.
        return false
    end
    err = MOI.UnsupportedNonlinearOperator(
        f.head,
        "the operator does not support symbolic differentiation",
    )
    return throw(err)
end

# Fallback: anything not handled below is already in simplest form.
simplify(f) = f

function simplify(f::MOI.ScalarAffineFunction{T}) where {T}
    # Canonicalize (merge duplicate terms, drop zeros); a term-free affine
    # function collapses to its constant.
    g = MOI.Utilities.canonical(f)
    return isempty(g.terms) ? g.constant : g
end

function simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    if !isempty(g.quadratic_terms)
        return g
    end
    # No quadratic part left: demote to affine and simplify further.
    return simplify(MOI.ScalarAffineFunction(g.affine_terms, g.constant))
end

# Non-nonlinear values are already as constant as they can be.
_eval_if_constant(f) = f

function _eval_if_constant(f::MOI.ScalarNonlinearFunction)
    # Constant-fold: when Base defines the operator and every argument is a
    # numeric literal, evaluate the call now.
    if hasproperty(Base, f.head) && all(_isnum, f.args)
        return getproperty(Base, f.head)(f.args...)
    end
    return f
end

function simplify(f::MOI.ScalarNonlinearFunction)
    # Simplify the arguments first (note: this mutates `f.args` in place),
    # then apply the operator-specific rewrite for `f.head`, and finally
    # constant-fold if everything reduced to numeric literals.
    for i in 1:length(f.args)
        f.args[i] = simplify(f.args[i])
    end
    return _eval_if_constant(simplify(Val(f.head), f))
end

# Fallback: no operator-specific simplification for this head.
simplify(::Val, f::MOI.ScalarNonlinearFunction) = f

# Numeric predicates used by the simplifier. `true`/`false` serve as strong
# one/zero constants throughout this module, and `Bool <: Real`, so a single
# `Real` method covers them. Generalized from `Union{Bool,Integer,Float64}`
# to `Real` so other numeric literal types (`Float32`, `Rational`, ...) are
# recognized as well; this is a backward-compatible superset.
_iszero(x::Real) = iszero(x)
_iszero(::Any) = false

_isone(x::Real) = isone(x)
_isone(::Any) = false

# Is `x` a plain numeric literal (as opposed to an MOI function)?
_isnum(::Real) = true
_isnum(::Any) = false

# `_isexpr(f, head[, n])`: `true` iff `f` is a ScalarNonlinearFunction whose
# operator is `head` (and, when `n` is given, with exactly `n` arguments).
_isexpr(::Any, ::Symbol, n::Int = 0) = false

_isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol) = f.head == head

function _isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol, n::Int)
    return f.head == head && length(f.args) == n
end

# Simplification for products: flatten nested `*`, drop multiplicative
# identities, fold literal constants together, and propagate a strong zero.
function simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
    new_args = Any[]
    # Index into `new_args` of the first Real constant folded in (0 if none
    # yet); later constants are multiplied into that slot in place, so the
    # folded constant keeps the position of the first literal seen.
    first_constant = 0
    for arg in f.args
        if _isexpr(arg, :*)
            # Flatten one level: *(x, *(y, z)) => *(x, y, z). Constants
            # inside the nested product are appended verbatim, not re-folded.
            append!(new_args, arg.args)
        elseif _iszero(arg)
            # 0 * anything => 0; `false` is the canonical strong zero.
            return false
        elseif _isone(arg)
            # 1 * x => x: skip multiplicative identities.
        elseif arg isa Real
            if first_constant == 0
                push!(new_args, arg)
                first_constant = length(new_args)
            else
                new_args[first_constant] *= arg
            end
        else
            push!(new_args, arg)
        end
    end
    if isempty(new_args)
        # Empty product: the multiplicative identity, as a strong 1.
        return true
    elseif length(new_args) == 1
        return only(new_args)
    end
    return MOI.ScalarNonlinearFunction(:*, new_args)
end

# Simplification for sums: unwrap unary `+`, rewrite `x + (-y)` as `x - y`,
# flatten nested `+`, drop zeros, and fold literal constants together.
function simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
    if length(f.args) == 1
        # +(x) => x
        return only(f.args)
    elseif length(f.args) == 2 && _isexpr(f.args[2], :-, 1)
        # x + (-y) => x - y
        return MOI.ScalarNonlinearFunction(
            :-,
            Any[f.args[1], f.args[2].args[1]],
        )
    end
    new_args = Any[]
    # Index into `new_args` of the first Real constant folded in (0 if none
    # yet); later constants are added into that slot in place.
    first_constant = 0
    for arg in f.args
        if _isexpr(arg, :+)
            # Flatten one level: +(x, +(y, z)) => +(x, y, z). Constants
            # inside the nested sum are appended verbatim, not re-folded.
            append!(new_args, arg.args)
        elseif _iszero(arg)
            # x + 0 => x: drop additive identities.
        elseif arg isa Real
            if first_constant == 0
                push!(new_args, arg)
                first_constant = length(new_args)
            else
                new_args[first_constant] += arg
            end
        else
            push!(new_args, arg)
        end
    end
    if isempty(new_args)
        # Empty sum: the additive identity, as a strong 0.
        return false
    elseif length(new_args) == 1
        return only(new_args)
    end
    return MOI.ScalarNonlinearFunction(:+, new_args)
end

# Simplification for negation/subtraction: cancel double negation, drop
# zeros, cancel identical operands, and turn subtraction of a negation into
# addition. Anything else is returned unchanged.
function simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
    if length(f.args) == 1
        if _isexpr(f.args[1], :-, 1)
            # -(-(x)) => x
            return f.args[1].args[1]
        end
    elseif length(f.args) == 2
        if _iszero(f.args[1])
            # 0 - x => -x
            return MOI.ScalarNonlinearFunction(:-, Any[f.args[2]])
        elseif _iszero(f.args[2])
            # x - 0 => x
            return f.args[1]
        elseif f.args[1] == f.args[2]
            # x - x => 0 (`false` is the strong zero)
            return false
        elseif _isexpr(f.args[2], :-, 1)
            # x - -(y) => x + y
            return MOI.ScalarNonlinearFunction(
                :+,
                Any[f.args[1], f.args[2].args[1]],
            )
        end
    end
    return f
end

# Simplification for powers. NOTE(review): `0^x => 0` is unsound for
# `x <= 0` (in Julia `0^0 == 1`, and `0^x` for negative `x` is an
# error/Inf), and `x^0 => 1` assumes `x^0 == 1` even at `x == 0`. These
# follow the usual symbolic-algebra convention (and the tests below rely on
# them) but should be confirmed acceptable for the intended use.
function simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
    if _iszero(f.args[2])
        # x^0 => 1
        return true
    elseif _isone(f.args[2])
        # x^1 => x
        return f.args[1]
    elseif _iszero(f.args[1])
        # 0^x => 0
        return false
    elseif _isone(f.args[1])
        # 1^x => 1
        return true
    end
    return f
end

"""
    variables(f)

Return the distinct `MOI.VariableIndex`es appearing in `f`, ordered by first
appearance.
"""
function variables(f::MOI.AbstractScalarFunction)
    indices = MOI.VariableIndex[]
    variables!(indices, f)
    return indices
end

# A numeric constant contains no variables.
variables(::Real) = MOI.VariableIndex[]
variables!(ret, ::Real) = nothing

function variables!(ret, f::MOI.VariableIndex)
    # Record each variable at most once, preserving first-seen order.
    f in ret || push!(ret, f)
    return
end

function variables!(ret, f::MOI.ScalarAffineTerm)
    # Delegate to the VariableIndex method, which handles de-duplication.
    variables!(ret, f.variable)
    return
end

function variables!(ret, f::MOI.ScalarAffineFunction)
    # Collect variables term by term, in term order.
    foreach(term -> variables!(ret, term), f.terms)
    return
end

function variables!(ret, f::MOI.ScalarQuadraticTerm)
    # Delegate to the VariableIndex method, which handles de-duplication.
    variables!(ret, f.variable_1)
    variables!(ret, f.variable_2)
    return
end

function variables!(ret, f::MOI.ScalarQuadraticFunction)
    # Affine terms are visited before quadratic terms, matching the order in
    # which they appear in `f`.
    foreach(term -> variables!(ret, term), f.affine_terms)
    foreach(term -> variables!(ret, term), f.quadratic_terms)
    return
end

function variables!(ret, f::MOI.ScalarNonlinearFunction)
    # Recurse through the expression tree.
    foreach(arg -> variables!(ret, arg), f.args)
    return
end

# The gradient of a constant is empty.
gradient(::Real) = Dict{MOI.VariableIndex,Any}()

"""
    gradient(f::MOI.AbstractScalarFunction)

Return a dictionary mapping each variable appearing in `f` to the simplified
symbolic partial derivative of `f` with respect to that variable.
"""
function gradient(f::MOI.AbstractScalarFunction)
    ret = Dict{MOI.VariableIndex,Any}()
    for x in variables(f)
        ret[x] = simplify(derivative(f, x))
    end
    return ret
end

end

using JuMP, Test

# Unit tests for `SymbolicAD.derivative`: each pair maps a JuMP expression
# `f` to its expected derivative `fp` with respect to `x`. The computed
# derivative is `simplify`ed before comparison.
function test_derivative()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # derivative(::Real, ::MOI.VariableIndex)
        1.0=>0.0,
        1.23=>0.0,
        # derivative(f::MOI.VariableIndex, x::MOI.VariableIndex)
        x=>1.0,
        y=>0.0,
        # derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex)
        1.0*x=>1.0,
        1.0*x+2.0=>1.0,
        2.0*x+2.0=>2.0,
        2.0*x+y+2.0=>2.0,
        2.0*x+y+z+2.0=>2.0,
        # derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex)
        QuadExpr(1.0 * x)=>1.0,
        QuadExpr(1.0 * x + 0.0 * y)=>1.0,
        x*y=>1.0*y,
        y*x=>1.0*y,
        x^2=>2.0*x,
        x^2+3x+4=>2.0*x+3.0,
        (x-1.0)^2=>2.0*(x-1),
        (3*x+1.0)^2=>6.0*(3x+1),
        # Univariate
        #   f.head == :+
        @force_nonlinear(+x)=>1,
        @force_nonlinear(+sin(x))=>cos(x),
        #   f.head == :-
        @force_nonlinear(-sin(x))=>-cos(x),
        #   f.head == :abs
        @force_nonlinear(
            abs(sin(x))
        )=>op_ifelse(op_greater_than_or_equal_to(sin(x), 0), 1, -1)*cos(x),
        #   f.head == :sign
        sign(x)=>false,
        # SYMBOLIC_UNIVARIATE_EXPRESSIONS
        sin(x)=>cos(x),
        cos(x)=>-sin(x),
        log(x)=>1/x,
        log(2x)=>1/(2x)*2.0,
        # f.head == :+
        sin(x)+cos(x)=>cos(x)-sin(x),
        # f.head == :-
        sin(x)-cos(x)=>cos(x)+sin(x),
        # f.head == :*
        @force_nonlinear(*(x, y, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, x, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, z, x))=>@force_nonlinear(*(y, z)),
        # :^
        sin(x)^2=>@force_nonlinear(*(2.0, sin(x), cos(x))),
        sin(x)^1=>cos(x),
        # :/
        @force_nonlinear(/(x, 2))=>0.5,
        @force_nonlinear(
            x^2 / (x + 1)
        )=>@force_nonlinear((*(2, x, x + 1) - x^2) / (x + 1)^2),
        # :ifelse
        op_ifelse(z, x^2, x)=>op_ifelse(z, 2x, 1),
        # :atan
        # :min
        min(x, x^2)=>op_ifelse(op_less_than_or_equal_to(x, min(x, x^2)), 1, 2x),
        # :max
        max(
            x,
            x^2,
        )=>op_ifelse(op_greater_than_or_equal_to(x, max(x, x^2)), 1, 2x),
        # comparisons
        op_greater_than_or_equal_to(x, y)=>false,
        op_equal_to(x, y)=>false,
    ]
        # Differentiate the MOI form of `f` with respect to `x`, simplify,
        # and compare against the expected MOI expression.
        g = SymbolicAD.derivative(moi_function(f), index(x))
        h = SymbolicAD.simplify(g)
        # Print diagnostics before the assertion so failures are easy to
        # debug.
        if !(h ≈ moi_function(fp))
            @show h
            @show f
        end
        @test h ≈ moi_function(fp)
    end
    return
end

# Unit tests for `SymbolicAD.gradient`: each pair maps a JuMP expression `f`
# to the expected dictionary of partial derivatives, keyed by variable.
function test_gradient()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>Dict(),
        # ::AffExpr
        x=>Dict(x => 1),
        x+y=>Dict(x => 1, y => 1),
        2x+y=>Dict(x => 2, y => 1),
        2x+3y+1=>Dict(x => 2, y => 3),
        # ::QuadExpr
        2x^2+3y+z=>Dict(x => 4x, y => 3, z => 1),
        # ::NonlinearExpr
        sin(x)=>Dict(x => cos(x)),
        sin(x + y)=>Dict(x => cos(x + y), y => cos(x + y)),
        sin(x + 2y)=>Dict(x => cos(x + 2y), y => cos(x + 2y) * 2),
    ]
        g = SymbolicAD.gradient(moi_function(f))
        # Convert the expected JuMP-keyed dictionary to MOI indices and MOI
        # functions before comparing entry by entry.
        h = Dict{MOI.VariableIndex,Any}(
            index(k) => moi_function(v) for (k, v) in fp
        )
        @test length(g) == length(h)
        for k in keys(g)
            @test g[k] ≈ h[k]
        end
    end
    return
end

# Unit tests for `SymbolicAD.simplify`: each pair maps a JuMP expression `f`
# to its expected simplified form `fp`.
function test_simplify()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # simplify(f)
        x=>x,
        # simplify(f::MOI.ScalarAffineFunction{T})
        AffExpr(2.0)=>2.0,
        # simplify(f::MOI.ScalarQuadraticFunction{T})
        QuadExpr(x + 1)=>x+1,
        # simplify(f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(sin(*(3, x^0)))=>sin(3),
        sin(log(x))=>sin(log(x)),
        op_ifelse(z, x, 0)=>op_ifelse(z, x, 0),
        # simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(*(x, *(y, z)))=>@force_nonlinear(*(x, y, z)),
        @force_nonlinear(
            *(x, *(y, z, *(x, 2)))
        )=>@force_nonlinear(*(x, y, z, x, 2)),
        @force_nonlinear(*(x, 3, 2))=>@force_nonlinear(*(x, 6)),
        @force_nonlinear(*(3, x, 2))=>@force_nonlinear(*(6, x)),
        @force_nonlinear(*(x, 1))=>x,
        @force_nonlinear(*(-(x, x), 1))=>0,
        # simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(+(x, +(y, z)))=>@force_nonlinear(+(x, y, z)),
        +(sin(x), -cos(x))=>sin(x)-cos(x),
        @force_nonlinear(+(x, 1, 2))=>@force_nonlinear(+(x, 3)),
        @force_nonlinear(+(1, x, 2))=>@force_nonlinear(+(3, x)),
        @force_nonlinear(+(x, 0))=>x,
        @force_nonlinear(+(-(x, x), 0))=>0,
        # simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(-(-(x)))=>x,
        @force_nonlinear(-(x, 0))=>x,
        @force_nonlinear(-(0, x))=>@force_nonlinear(-x),
        @force_nonlinear(-(x, x))=>0,
        @force_nonlinear(-(x, -y))=>@force_nonlinear(x + y),
        @force_nonlinear(-(x, y))=>@force_nonlinear(x - y),
        # simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(^(x, 0))=>1,
        @force_nonlinear(^(x, 1))=>x,
        @force_nonlinear(^(0, x))=>0,
        @force_nonlinear(^(1, x))=>1,
        x^y=>x^y,
    ]
        g = SymbolicAD.simplify(moi_function(f))
        # Print diagnostics before the assertion so failures are easy to
        # debug.
        if !(g ≈ moi_function(fp))
            @show f
            @show g
        end
        @test g ≈ moi_function(fp)
    end
    return
end

# Unit tests for `SymbolicAD.variables`: each pair maps a JuMP expression `f`
# to the expected list of variables, in order of first appearance.
function test_variable()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>[],
        # ::VariableRef,
        x=>[x],
        # ::AffExpr
        AffExpr(2.0)=>[],
        x+1=>[x],
        2x+1=>[x],
        2x+y+1=>[x, y],
        y+1+z=>[y, z],
        # ::QuadExpr
        zero(QuadExpr)=>[],
        QuadExpr(x + 1)=>[x],
        QuadExpr(x + 1 + y)=>[x, y],
        x^2=>[x],
        x^2+x=>[x],
        x^2+y=>[y, x],
        x*y=>[x, y],
        y*x=>[y, x],
        # ::NonlinearExpr
        sin(x)=>[x],
        sin(x + y)=>[x, y],
        sin(x)*cos(y)=>[x, y],
    ]
        @test SymbolicAD.variables(moi_function(f)) == index.(fp)
    end
    return
end

# Run the full SymbolicAD test suite, one nested testset per feature.
@testset "SymbolicAD" begin
    @testset "derivative" begin
        test_derivative()
    end
    @testset "simplify" begin
        test_simplify()
    end
    @testset "variable" begin
        test_variable()
    end
    @testset "gradient" begin
        test_gradient()
    end
end

# Return `nothing` so `include`-ing this file prints no trailing value.
nothing

I think the trick for integrating this into MathOptSymbolicAD is to have an efficient interpreter that re-uses expression values across the primal and derivatives evaluation. The symbolic expression trees are always going to be fundamentally limited.

odow avatar Sep 03 '24 23:09 odow

Thinking on this, I should probably merge this first into MathOptSymbolicAD.jl, get it working, and then we can add MathOptSymbolicAD as MOI.Nonlinear.SymbolicAD.

odow avatar Sep 11 '24 08:09 odow

Closing in favor of MathOptSymbolicAD for now.

odow avatar Jan 09 '25 00:01 odow