Symbolic AD of ScalarNonlinearFunction
This has come up quite a few times, so I think we need this.
I don't know what the right API is. Perhaps:
MOI.Nonlinear.gradient(f) :: Dict{MOI.VariableIndex,MOI.ScalarNonlinearFunction}
The use case for this would be:
- to replace Symbolics.jl in MathOptSymbolicAD
- Simple cases like https://discourse.julialang.org/t/auto-gradient-with-symbolics-for-jump-constraints/117953
It's okay for this to have all the usual issues with symbolic AD.
I started hacking something for this:
module SymbolicAD
import MacroTools
import MathOptInterface as MOI
# Base case: a numeric constant does not depend on any variable, so its
# derivative is zero. `false` is used as a strong zero that simplifies away.
derivative(::Real, ::MOI.VariableIndex) = false
"""
    derivative(f::MOI.VariableIndex, x::MOI.VariableIndex)

Return `true` (one) if `f` is the variable `x` and `false` (zero) otherwise.
"""
function derivative(f::MOI.VariableIndex, x::MOI.VariableIndex)
    # `ifelse(c, true, false)` is just `c`: the comparison already returns
    # the required Bool.
    return f == x
end
"""
    derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex) where {T}

Return the partial derivative of the affine function `f` with respect to `x`,
which is the sum of the coefficients attached to `x` (terms may be
duplicated, so all matching terms are accumulated).
"""
function derivative(
    f::MOI.ScalarAffineFunction{T},
    x::MOI.VariableIndex,
) where {T}
    return sum(
        (term.coefficient for term in f.terms if term.variable == x);
        init = zero(T),
    )
end
"""
    derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex) where {T}

Return the partial derivative of the quadratic function `f` with respect to
`x` as a `MOI.ScalarAffineFunction`: the constant collects the affine
coefficients on `x`, and each quadratic term involving `x` contributes one
affine term.
"""
function derivative(
    f::MOI.ScalarQuadraticFunction{T},
    x::MOI.VariableIndex,
) where {T}
    affine = MOI.ScalarAffineTerm{T}[]
    offset = zero(T)
    for term in f.affine_terms
        if term.variable == x
            offset += term.coefficient
        end
    end
    for term in f.quadratic_terms
        if term.variable_1 == term.variable_2 == x
            # Diagonal term: MOI stores the coefficient of x^2 doubled, so
            # d/dx of the term is `coefficient * x` directly.
            push!(affine, MOI.ScalarAffineTerm(term.coefficient, x))
        elseif term.variable_1 == x
            push!(
                affine,
                MOI.ScalarAffineTerm(term.coefficient, term.variable_2),
            )
        elseif term.variable_2 == x
            push!(
                affine,
                MOI.ScalarAffineTerm(term.coefficient, term.variable_1),
            )
        end
    end
    return MOI.ScalarAffineFunction(affine, offset)
end
"""
    derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)

Return an expression for the partial derivative of `f` with respect to `x`.

The result is not simplified; call `simplify` on it to obtain a compact
expression.

Throws `MOI.UnsupportedNonlinearOperator` if `f.head` has no implemented
differentiation rule.
"""
function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
    if length(f.args) == 1
        u = only(f.args)
        if f.head == :+
            # d/dx +u = du/dx
            return derivative(u, x)
        elseif f.head == :-
            # d/dx -u = -du/dx
            return MOI.ScalarNonlinearFunction(:-, Any[derivative(u, x)])
        elseif f.head == :abs
            # d/dx |u| = ±du/dx; ifelse picks +1 for u >= 0 (so the
            # subgradient chosen at u == 0 is +1).
            scale = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[MOI.ScalarNonlinearFunction(:>=, Any[u, 0]), 1, -1],
            )
            return MOI.ScalarNonlinearFunction(:*, Any[scale, derivative(u, x)])
        elseif f.head == :sign
            # sign is piecewise constant: derivative is 0 almost everywhere.
            return false
        end
        for (key, df, _) in MOI.Nonlinear.SYMBOLIC_UNIVARIATE_EXPRESSIONS
            if key == f.head
                # The chain rule: d(f(g(x))) / dx = f'(g(x)) * g'(x)
                # `df` is an `Expr` in the symbol `:x`; substitute `u` for
                # `:x` and convert nested `Expr` calls into
                # `ScalarNonlinearFunction`s.
                df_du = MacroTools.postwalk(df) do node
                    if node === :x
                        return u
                    elseif Meta.isexpr(node, :call)
                        op, args = node.args[1], node.args[2:end]
                        return MOI.ScalarNonlinearFunction(op, args)
                    end
                    return node
                end
                du_dx = derivative(u, x)
                return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
            end
        end
    end
    if f.head == :+
        # d/dx(+(args...)) = +(d/dx args)
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:+, args)
    elseif f.head == :-
        # d/dx(-(args...)) = -(d/dx args)
        # Note that - is not unary here because that would be caught above.
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:-, args)
    elseif f.head == :*
        # Product rule: d/dx(*(args...)) = sum(d{i}/dx * args\{i})
        sum_terms = Any[]
        for i in 1:length(f.args)
            g = MOI.ScalarNonlinearFunction(:*, copy(f.args))
            g.args[i] = derivative(f.args[i], x)
            push!(sum_terms, g)
        end
        return MOI.ScalarNonlinearFunction(:+, sum_terms)
    elseif f.head == :^
        @assert length(f.args) == 2
        u, p = f.args
        du_dx = derivative(u, x)
        dp_dx = derivative(p, x)
        if _iszero(dp_dx)
            # p is constant and does not depend on x:
            #   d/dx u^p = p * u^(p-1) * du/dx
            df_du = MOI.ScalarNonlinearFunction(
                :*,
                Any[p, MOI.ScalarNonlinearFunction(:^, Any[u, p-1])],
            )
            return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
        else
            # General case u(x)^p(x). From u^p = exp(p * log(u)):
            #   d/dx u^p = u^p * (dp/dx * log(u) + p * (du/dx) / u)
            # This carries the usual symbolic-AD caveat of requiring u > 0.
            return MOI.ScalarNonlinearFunction(
                :*,
                Any[
                    f,
                    MOI.ScalarNonlinearFunction(
                        :+,
                        Any[
                            MOI.ScalarNonlinearFunction(
                                :*,
                                Any[
                                    dp_dx,
                                    MOI.ScalarNonlinearFunction(:log, Any[u]),
                                ],
                            ),
                            MOI.ScalarNonlinearFunction(
                                :/,
                                Any[
                                    MOI.ScalarNonlinearFunction(
                                        :*,
                                        Any[p, du_dx],
                                    ),
                                    u,
                                ],
                            ),
                        ],
                    ),
                ],
            )
        end
    elseif f.head == :/
        # Quotient rule: d/dx(u / v) = ((du/dx)*v - u*(dv/dx)) / v^2
        @assert length(f.args) == 2
        u, v = f.args
        du_dx, dv_dx = derivative(u, x), derivative(v, x)
        return MOI.ScalarNonlinearFunction(
            :/,
            Any[
                MOI.ScalarNonlinearFunction(
                    :-,
                    Any[
                        MOI.ScalarNonlinearFunction(:*, Any[du_dx, v]),
                        MOI.ScalarNonlinearFunction(:*, Any[u, dv_dx]),
                    ],
                ),
                MOI.ScalarNonlinearFunction(:^, Any[v, 2]),
            ],
        )
    elseif f.head == :ifelse
        @assert length(f.args) == 3
        # Pick the derivative of the active branch; the condition is treated
        # as locally constant.
        return MOI.ScalarNonlinearFunction(
            :ifelse,
            Any[f.args[1], derivative(f.args[2], x), derivative(f.args[3], x)],
        )
    elseif f.head == :atan
        # Two-argument arctangent:
        #   d/dx atan(u, v) = (v * du/dx - u * dv/dx) / (u^2 + v^2)
        @assert length(f.args) == 2
        u, v = f.args
        du_dx, dv_dx = derivative(u, x), derivative(v, x)
        return MOI.ScalarNonlinearFunction(
            :/,
            Any[
                MOI.ScalarNonlinearFunction(
                    :-,
                    Any[
                        MOI.ScalarNonlinearFunction(:*, Any[v, du_dx]),
                        MOI.ScalarNonlinearFunction(:*, Any[u, dv_dx]),
                    ],
                ),
                MOI.ScalarNonlinearFunction(
                    :+,
                    Any[
                        MOI.ScalarNonlinearFunction(:^, Any[u, 2]),
                        MOI.ScalarNonlinearFunction(:^, Any[v, 2]),
                    ],
                ),
            ],
        )
    elseif f.head == :min
        # d/dx min(args...): the derivative of whichever argument attains
        # the minimum, built as a right-to-left chain of ifelse.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(<=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head == :max
        # d/dx max(args...): mirror image of the :min case.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(>=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head in (:(>=), :(<=), :(<), :(>), :(==))
        # Comparisons are piecewise constant.
        return false
    end
    err = MOI.UnsupportedNonlinearOperator(
        f.head,
        "the operator does not support symbolic differentiation",
    )
    return throw(err)
end
# Fallback: constants and variables are already in simplest form.
simplify(f) = f
"""
    simplify(f::MOI.ScalarAffineFunction{T}) where {T}

Canonicalize `f`, collapsing it to its constant when no terms remain.
"""
function simplify(f::MOI.ScalarAffineFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    return isempty(g.terms) ? g.constant : g
end
"""
    simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}

Canonicalize `f`, demoting it to an affine function (and possibly further to
a constant) when no quadratic terms remain.
"""
function simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    if !isempty(g.quadratic_terms)
        return g
    end
    return simplify(MOI.ScalarAffineFunction(g.affine_terms, g.constant))
end
# Fold `f` to a number when every argument is a numeric literal and the
# operator names a function available in `Base`.
#
# `isdefined(Base, s)` replaces the original `hasproperty(Base, s)`:
# `hasproperty` on a Module materializes `propertynames(Base)` (an array of
# every name in Base) on each call, whereas `isdefined` is a direct binding
# lookup that also covers bindings Base imports from other modules.
function _eval_if_constant(f::MOI.ScalarNonlinearFunction)
    if all(_isnum, f.args) && isdefined(Base, f.head)
        return getproperty(Base, f.head)(f.args...)
    end
    return f
end
# Fallback: anything that is not a nonlinear expression cannot be folded.
_eval_if_constant(other) = other
"""
    simplify(f::MOI.ScalarNonlinearFunction)

Recursively simplify the arguments of `f` in-place, apply the head-specific
rule via `simplify(Val(f.head), f)`, then fold the result to a constant when
every argument is numeric.
"""
function simplify(f::MOI.ScalarNonlinearFunction)
    map!(simplify, f.args, f.args)
    return _eval_if_constant(simplify(Val(f.head), f))
end
# Fallback for operators without a specialized simplification rule.
simplify(::Val, f::MOI.ScalarNonlinearFunction) = f
# Numeric-literal predicates used by `simplify` and `_eval_if_constant`.
#
# The original signatures were `Union{Bool,Integer,Float64}`, in which `Bool`
# is redundant (`Bool <: Integer`) and which was inconsistent with the
# `arg isa Real` checks in `simplify(::Val{:*}, ...)` and
# `simplify(::Val{:+}, ...)`: a `Float32` or `Rational` constant was folded
# there but not recognized here. Dispatching on `Real` covers all numeric
# literals uniformly.
_iszero(x::Real) = iszero(x)
_iszero(::Any) = false
_isone(x::Real) = isone(x)
_isone(::Any) = false
_isnum(::Real) = true
_isnum(::Any) = false
# Check whether `f` is a `ScalarNonlinearFunction` with head `head` (and,
# when `n` is given, exactly `n` arguments). Anything else is never a match.
_isexpr(::Any, ::Symbol, n::Int = 0) = false
_isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol) = f.head == head
function _isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol, n::Int)
    return length(f.args) == n && _isexpr(f, head)
end
"""
    simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)

Simplify a product: flatten nested `*`, return zero on a zero factor, drop
unit factors, and fold numeric constants into a single coefficient.
"""
function simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
    factors = Any[]
    constant_index = 0  # position in `factors` of the folded numeric constant
    for a in f.args
        if _isexpr(a, :*)
            # *(x, *(y, z)) => *(x, y, z)
            append!(factors, a.args)
        elseif _iszero(a)
            # 0 * anything => 0
            return false
        elseif _isone(a)
            # Drop multiplicative identities.
        elseif a isa Real
            if constant_index == 0
                push!(factors, a)
                constant_index = length(factors)
            else
                factors[constant_index] *= a
            end
        else
            push!(factors, a)
        end
    end
    if isempty(factors)
        # Every factor was 1.
        return true
    elseif length(factors) == 1
        return only(factors)
    end
    return MOI.ScalarNonlinearFunction(:*, factors)
end
"""
    simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)

Simplify a sum: unwrap a unary `+`, rewrite `x + (-y)` as `x - y`, flatten
nested `+`, drop zeros, and fold numeric constants into a single term.
"""
function simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
    if length(f.args) == 1
        # +(x) => x
        return only(f.args)
    elseif length(f.args) == 2 && _isexpr(f.args[2], :-, 1)
        # +(x, -(y)) => -(x, y)
        return MOI.ScalarNonlinearFunction(
            :-,
            Any[f.args[1], f.args[2].args[1]],
        )
    end
    terms = Any[]
    constant_index = 0  # position in `terms` of the folded numeric constant
    for a in f.args
        if _isexpr(a, :+)
            # +(x, +(y, z)) => +(x, y, z)
            append!(terms, a.args)
        elseif _iszero(a)
            # Drop additive identities.
        elseif a isa Real
            if constant_index == 0
                push!(terms, a)
                constant_index = length(terms)
            else
                terms[constant_index] += a
            end
        else
            push!(terms, a)
        end
    end
    if isempty(terms)
        # Every summand was 0.
        return false
    elseif length(terms) == 1
        return only(terms)
    end
    return MOI.ScalarNonlinearFunction(:+, terms)
end
"""
    simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)

Simplify a negation or subtraction: `-(-x) => x`, `0 - x => -x`,
`x - 0 => x`, `x - x => 0`, and `x - (-y) => x + y`.
"""
function simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
    n = length(f.args)
    if n == 1 && _isexpr(f.args[1], :-, 1)
        # -(-(x)) => x
        return f.args[1].args[1]
    elseif n == 2
        a, b = f.args
        if _iszero(a)
            # 0 - x => -x
            return MOI.ScalarNonlinearFunction(:-, Any[b])
        elseif _iszero(b)
            # x - 0 => x
            return a
        elseif a == b
            # x - x => 0
            return false
        elseif _isexpr(b, :-, 1)
            # x - -(y) => x + y
            return MOI.ScalarNonlinearFunction(:+, Any[a, b.args[1]])
        end
    end
    return f
end
"""
    simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)

Simplify a power when either the base or the exponent is a literal 0 or 1.
"""
function simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
    base, exponent = f.args[1], f.args[2]
    if _iszero(exponent)
        # x^0 => 1 (adopts the convention 0^0 == 1)
        return true
    elseif _isone(exponent)
        # x^1 => x
        return base
    elseif _iszero(base)
        # 0^x => 0 (NOTE: assumes x > 0; 0^0 and 0^(-1) differ)
        return false
    elseif _isone(base)
        # 1^x => 1
        return true
    end
    return f
end
"""
    variables(f::MOI.AbstractScalarFunction)

Return the variables appearing in `f`, ordered by first appearance.
"""
function variables(f::MOI.AbstractScalarFunction)
    indices = MOI.VariableIndex[]
    variables!(indices, f)
    return indices
end
# A constant contains no variables.
variables(::Real) = MOI.VariableIndex[]
# Constants contribute nothing to the accumulator `ret`.
variables!(ret, ::Real) = nothing
# Record a variable, preserving first-appearance order and uniqueness.
function variables!(ret, f::MOI.VariableIndex)
    f in ret || push!(ret, f)
    return
end
# Record the variable of an affine term if it has not been seen yet.
function variables!(ret, f::MOI.ScalarAffineTerm)
    f.variable in ret || push!(ret, f.variable)
    return
end
# Collect the variables of every affine term in `f`.
function variables!(ret, f::MOI.ScalarAffineFunction)
    foreach(term -> variables!(ret, term), f.terms)
    return
end
# Record both variables of a quadratic term, first `variable_1` then
# `variable_2`, skipping any already seen.
function variables!(ret, f::MOI.ScalarQuadraticTerm)
    for v in (f.variable_1, f.variable_2)
        v in ret || push!(ret, v)
    end
    return
end
# Collect variables from the affine terms first, then the quadratic terms,
# matching the ordering used elsewhere.
function variables!(ret, f::MOI.ScalarQuadraticFunction)
    for term in Iterators.flatten((f.affine_terms, f.quadratic_terms))
        variables!(ret, term)
    end
    return
end
# Recurse into every argument of a nonlinear expression.
function variables!(ret, f::MOI.ScalarNonlinearFunction)
    foreach(arg -> variables!(ret, arg), f.args)
    return
end
# A constant has an empty gradient.
gradient(::Real) = Dict{MOI.VariableIndex,Any}()
"""
    gradient(f::MOI.AbstractScalarFunction)

Return a dictionary mapping each variable appearing in `f` to the simplified
expression for the partial derivative of `f` with respect to that variable.
"""
function gradient(f::MOI.AbstractScalarFunction)
    grad = Dict{MOI.VariableIndex,Any}()
    for v in variables(f)
        grad[v] = simplify(derivative(f, v))
    end
    return grad
end
end
using JuMP, Test
"""
    test_derivative()

Test `SymbolicAD.derivative` against hand-computed derivatives: each pair
`f => fp` asserts that `simplify(derivative(f, x)) ≈ fp` in MOI form.
"""
function test_derivative()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # derivative(::Real, ::MOI.VariableIndex)
        1.0=>0.0,
        1.23=>0.0,
        # derivative(f::MOI.VariableIndex, x::MOI.VariableIndex)
        x=>1.0,
        y=>0.0,
        # derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex)
        1.0*x=>1.0,
        1.0*x+2.0=>1.0,
        2.0*x+2.0=>2.0,
        2.0*x+y+2.0=>2.0,
        2.0*x+y+z+2.0=>2.0,
        # derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex)
        QuadExpr(1.0 * x)=>1.0,
        QuadExpr(1.0 * x + 0.0 * y)=>1.0,
        x*y=>1.0*y,
        y*x=>1.0*y,
        x^2=>2.0*x,
        x^2+3x+4=>2.0*x+3.0,
        (x-1.0)^2=>2.0*(x-1),
        (3*x+1.0)^2=>6.0*(3x+1),
        # Univariate
        # f.head == :+
        @force_nonlinear(+x)=>1,
        @force_nonlinear(+sin(x))=>cos(x),
        # f.head == :-
        @force_nonlinear(-sin(x))=>-cos(x),
        # f.head == :abs
        @force_nonlinear(
            abs(sin(x))
        )=>op_ifelse(op_greater_than_or_equal_to(sin(x), 0), 1, -1)*cos(x),
        # f.head == :sign
        sign(x)=>false,
        # SYMBOLIC_UNIVARIATE_EXPRESSIONS
        sin(x)=>cos(x),
        cos(x)=>-sin(x),
        log(x)=>1/x,
        log(2x)=>1/(2x)*2.0,
        # f.head == :+
        sin(x)+cos(x)=>cos(x)-sin(x),
        # f.head == :-
        sin(x)-cos(x)=>cos(x)+sin(x),
        # f.head == :*
        @force_nonlinear(*(x, y, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, x, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, z, x))=>@force_nonlinear(*(y, z)),
        # :^
        sin(x)^2=>@force_nonlinear(*(2.0, sin(x), cos(x))),
        sin(x)^1=>cos(x),
        # :/
        @force_nonlinear(/(x, 2))=>0.5,
        @force_nonlinear(
            x^2 / (x + 1)
        )=>@force_nonlinear((*(2, x, x + 1) - x^2) / (x + 1)^2),
        # :ifelse
        op_ifelse(z, x^2, x)=>op_ifelse(z, 2x, 1),
        # :atan
        # :min
        min(x, x^2)=>op_ifelse(op_less_than_or_equal_to(x, min(x, x^2)), 1, 2x),
        # :max
        max(
            x,
            x^2,
        )=>op_ifelse(op_greater_than_or_equal_to(x, max(x, x^2)), 1, 2x),
        # comparisons
        op_greater_than_or_equal_to(x, y)=>false,
        op_equal_to(x, y)=>false,
    ]
        g = SymbolicAD.derivative(moi_function(f), index(x))
        h = SymbolicAD.simplify(g)
        # Print debugging detail before the @test macro reports a failure.
        if !(h ≈ moi_function(fp))
            @show h
            @show f
        end
        @test h ≈ moi_function(fp)
    end
    return
end
"""
    test_gradient()

Test `SymbolicAD.gradient` against hand-computed gradients: each pair
`f => fp` asserts that the computed gradient dictionary matches `fp`
entry-by-entry in MOI form.
"""
function test_gradient()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>Dict(),
        # ::AffExpr
        x=>Dict(x => 1),
        x+y=>Dict(x => 1, y => 1),
        2x+y=>Dict(x => 2, y => 1),
        2x+3y+1=>Dict(x => 2, y => 3),
        # ::QuadExpr
        2x^2+3y+z=>Dict(x => 4x, y => 3, z => 1),
        # ::NonlinearExpr
        sin(x)=>Dict(x => cos(x)),
        sin(x + y)=>Dict(x => cos(x + y), y => cos(x + y)),
        sin(x + 2y)=>Dict(x => cos(x + 2y), y => cos(x + 2y) * 2),
    ]
        g = SymbolicAD.gradient(moi_function(f))
        # Convert the expected JuMP expressions into MOI functions keyed by
        # MOI variable indices before comparing.
        h = Dict{MOI.VariableIndex,Any}(
            index(k) => moi_function(v) for (k, v) in fp
        )
        @test length(g) == length(h)
        for k in keys(g)
            @test g[k] ≈ h[k]
        end
    end
    return
end
"""
    test_simplify()

Test `SymbolicAD.simplify`: each pair `f => fp` asserts that the simplified
MOI form of `f` is `≈` to the MOI form of `fp`.
"""
function test_simplify()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # simplify(f)
        x=>x,
        # simplify(f::MOI.ScalarAffineFunction{T})
        AffExpr(2.0)=>2.0,
        # simplify(f::MOI.ScalarQuadraticFunction{T})
        QuadExpr(x + 1)=>x+1,
        # simplify(f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(sin(*(3, x^0)))=>sin(3),
        sin(log(x))=>sin(log(x)),
        op_ifelse(z, x, 0)=>op_ifelse(z, x, 0),
        # simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(*(x, *(y, z)))=>@force_nonlinear(*(x, y, z)),
        @force_nonlinear(
            *(x, *(y, z, *(x, 2)))
        )=>@force_nonlinear(*(x, y, z, x, 2)),
        @force_nonlinear(*(x, 3, 2))=>@force_nonlinear(*(x, 6)),
        @force_nonlinear(*(3, x, 2))=>@force_nonlinear(*(6, x)),
        @force_nonlinear(*(x, 1))=>x,
        @force_nonlinear(*(-(x, x), 1))=>0,
        # simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(+(x, +(y, z)))=>@force_nonlinear(+(x, y, z)),
        +(sin(x), -cos(x))=>sin(x)-cos(x),
        @force_nonlinear(+(x, 1, 2))=>@force_nonlinear(+(x, 3)),
        @force_nonlinear(+(1, x, 2))=>@force_nonlinear(+(3, x)),
        @force_nonlinear(+(x, 0))=>x,
        @force_nonlinear(+(-(x, x), 0))=>0,
        # simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(-(-(x)))=>x,
        @force_nonlinear(-(x, 0))=>x,
        @force_nonlinear(-(0, x))=>@force_nonlinear(-x),
        @force_nonlinear(-(x, x))=>0,
        @force_nonlinear(-(x, -y))=>@force_nonlinear(x + y),
        @force_nonlinear(-(x, y))=>@force_nonlinear(x - y),
        # simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(^(x, 0))=>1,
        @force_nonlinear(^(x, 1))=>x,
        @force_nonlinear(^(0, x))=>0,
        @force_nonlinear(^(1, x))=>1,
        x^y=>x^y,
    ]
        g = SymbolicAD.simplify(moi_function(f))
        # Print debugging detail before the @test macro reports a failure.
        if !(g ≈ moi_function(fp))
            @show f
            @show g
        end
        @test g ≈ moi_function(fp)
    end
    return
end
"""
    test_variable()

Test `SymbolicAD.variables`: each pair `f => fp` asserts that the variables
of `f` are exactly `fp`, in first-appearance order.
"""
function test_variable()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>[],
        # ::VariableRef,
        x=>[x],
        # ::AffExpr
        AffExpr(2.0)=>[],
        x+1=>[x],
        2x+1=>[x],
        2x+y+1=>[x, y],
        y+1+z=>[y, z],
        # ::QuadExpr
        zero(QuadExpr)=>[],
        QuadExpr(x + 1)=>[x],
        QuadExpr(x + 1 + y)=>[x, y],
        x^2=>[x],
        x^2+x=>[x],
        x^2+y=>[y, x],
        x*y=>[x, y],
        y*x=>[y, x],
        # ::NonlinearExpr
        sin(x)=>[x],
        sin(x + y)=>[x, y],
        sin(x)*cos(y)=>[x, y],
    ]
        @test SymbolicAD.variables(moi_function(f)) == index.(fp)
    end
    return
end
# Top-level entry point: run each test function in its own named testset.
@testset "SymbolicAD" begin
    @testset "derivative" begin
        test_derivative()
    end
    @testset "simplify" begin
        test_simplify()
    end
    @testset "variable" begin
        test_variable()
    end
    @testset "gradient" begin
        test_gradient()
    end
end
nothing  # suppress printing of the testset result when this file is included
I think the trick for integrating this into MathOptSymbolicAD is to have an efficient interpreter that re-uses expression values across the primal and derivatives evaluation. The symbolic expression trees are always going to be fundamentally limited.
Thinking on this, I should probably merge this first into MathOptSymbolicAD.jl, get it working, and then we can add MathOptSymbolicAD as MOI.Nonlinear.SymbolicAD.
Closing in favor of MathOptSymbolicAD for now.