SymbolicUtils.jl

Use Mjolnir for tracing

Open • MikeInnes opened this pull request 5 years ago • 13 comments

This is just a prototype for now (Mjolnir is not registered yet), but it shows off the basic ideas. This is the right solution to #16 (handling dispatch correctly), #8 (handling typeof etc. correctly) and #67 (handling ifelse and if/else).

julia> using SymbolicUtils

julia> @symbolic 3x^2 + 2y + 1
1 + (3 * (x ^ 2)) + (2 * y)

julia> f(x::Real) = abs2(Complex(x, 2x))
f (generic function with 1 method)

julia> @symbolic f(x)+y
(5 * (x ^ 2)) + y
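The traced result can be checked by hand. For real x, `Complex(x, 2x)` is x + 2xi, and `abs2` returns the squared modulus:

```latex
\lvert x + 2x\,i \rvert^{2} = x^{2} + (2x)^{2} = x^{2} + 4x^{2} = 5x^{2}
```

So `f(x) + y` reduces to 5x^2 + y, which is what `@symbolic` printed.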

Demo of #8:

julia> _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
_qreltype (generic function with 1 method)

julia> @symbolic zero(_qreltype(typeof(x)))
0.0
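For reference, this is what the helper computes when `T` is an ordinary concrete type, in plain Julia with no tracing involved. Issue #8 concerns getting `typeof` and similar type-level calls like these right while tracing; the sketch below only shows the non-symbolic behaviour.

```julia
# Same helper as in the REPL session above, evaluated on concrete types.
_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))

_qreltype(Int)         # Float64: sqrt of an Int returns a Float64
_qreltype(Float32)     # Float32
_qreltype(ComplexF64)  # ComplexF64: abs2 is real, but zero(T) keeps the complex type
zero(_qreltype(Int))   # 0.0
```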

MikeInnes, May 08 '20 11:05

ifelse works:

julia> relu(x::Real) = ifelse(x == 0, x, zero(x))
relu (generic function with 2 methods)

julia> @symbolic relu(x)
ifelse(x == 0, x, 0.0)

I forgot to mention that there's a syntax for declaring variable types:

julia> @symbolic relu(x::Int64)
ifelse(x == 0, x, 0)
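A likely reason `ifelse` is the easier case, while `if`/`else` needs separate handling (#67): `ifelse` is an ordinary function whose arguments are all evaluated before the call, whereas `?:` and `if` are control flow that evaluate only the taken branch. A small plain-Julia sketch; the helper names are illustrative only.

```julia
# Record which branch expressions get evaluated under each construct.
function eval_order_ifelse(c::Bool)
    seen = String[]
    note(s, v) = (push!(seen, s); v)
    r = ifelse(c, note("then", 1), note("else", 2))  # both arguments run before the call
    return r, seen
end

function eval_order_ternary(c::Bool)
    seen = String[]
    note(s, v) = (push!(seen, s); v)
    r = c ? note("then", 1) : note("else", 2)        # only the taken branch runs
    return r, seen
end

eval_order_ifelse(true)   # (1, ["then", "else"])
eval_order_ternary(true)  # (1, ["then"])
```

Because every operand of `ifelse` already exists as a value, a tracer can record it as a single call, as in the `ifelse(x == 0, x, 0)` output above.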

Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir.

MikeInnes, May 08 '20 12:05

Codecov Report

Merging #78 into master will decrease coverage by 2.46%. The diff coverage is 0.00%.


@@            Coverage Diff             @@
##           master      #78      +/-   ##
==========================================
- Coverage   89.81%   87.34%   -2.47%     
==========================================
  Files           8        9       +1     
  Lines         432      490      +58     
==========================================
+ Hits          388      428      +40     
- Misses         44       62      +18     
Impacted Files          Coverage Δ
src/SymbolicUtils.jl    100.00% <ø>      (ø)
src/trace.jl              0.00% <0.00%>  (ø)
src/rulesets.jl          50.00% <0.00%>  (ø)
src/matchers.jl          90.09% <0.00%>  (+0.30%) ↑
src/rule_dsl.jl          98.87% <0.00%>  (+1.54%) ↑
src/simplify.jl          95.08% <0.00%>  (+2.90%) ↑
src/types.jl             89.58% <0.00%>  (+2.91%) ↑
src/methods.jl           96.29% <0.00%>  (+7.40%) ↑
src/util.jl              76.19% <0.00%>  (+9.52%) ↑


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 5a8f2f9...fd1e2af.

codecov-io, May 08 '20 12:05

Absolutely fantastic!! I'm excited.

It might be good to add a Symutils -> IR conversion given a vector of arguments.

shashi, May 08 '20 15:05

Amazing, Mike! This is a really promising direction.

MasonProtter, May 08 '20 16:05

Glad you like it!

> It might be good to add a Symutils -> IR conversion given a vector of arguments.

Yeah, that seems useful – especially for ultimately turning a Term into something you can evaluate.

MikeInnes, May 08 '20 16:05

> Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir.

This is something we need to start adding methods for. Right now I think an array of symbols is all that works. We could use something like Mjolnir's shaped arrays, with the shaped-array type stored as the T in Symbolic{T}.

shashi, May 08 '20 18:05

Wow, fantastic. Thanks, Mike!

YingboMa, May 08 '20 18:05

julia> using SymbolicUtils

julia> @syms a b c
(a, b, c)

julia> ir=IR(a+b+c, [a,b]; mod=Main)
1: (%1, %2)
  %3 = (+)(%1, %2, Main.c)
  return %3

julia> h=func(ir)
##260 (generic function with 1 method)

julia> h(2,3)
5 + c
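To make the shape of that IR concrete, here is a toy sketch that flattens a plain Julia `Expr` into numbered SSA-style statements given an argument list. `to_ssa` is a hypothetical helper written for illustration: it is not the `IR` constructor used above, and it returns strings rather than an IRTools `IR` object. Arguments become `%1, %2, ...`, and any other symbol is left as a free reference, like `Main.c` above.

```julia
# Hypothetical helper: linearize a call-only Expr into SSA-style statement strings.
function to_ssa(ex, args::Vector{Symbol})
    env = Dict{Symbol,String}(a => "%$i" for (i, a) in enumerate(args))
    stmts = String[]
    counter = Ref(length(args))          # arguments occupy %1..%n
    function walk(e)
        if e isa Symbol
            return get(env, e, string(e))  # non-argument symbols stay free
        elseif e isa Expr && e.head == :call
            f = e.args[1]
            operands = map(walk, e.args[2:end])
            counter[] += 1
            v = "%$(counter[])"
            push!(stmts, "$v = ($f)($(join(operands, ", ")))")
            return v
        else
            return string(e)             # literals (and anything else) printed as-is
        end
    end
    push!(stmts, "return $(walk(ex))")
    return stmts
end

to_ssa(:(a + b + c), [:a, :b])  # ["%3 = (+)(%1, %2, c)", "return %3"]
```

Like the unsimplified derivative IR later in this thread, the sketch does no common-subexpression elimination, so a repeated subterm produces repeated statements.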

Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated?

shashi, May 10 '20 16:05

I made the macro use values from the current environment: if a symbolic variable is passed, its symtype is used for tracing, and everything else is treated as a Const.

This lets you call functions with any mix of symbolic and concrete arguments:

julia> f(y, x) = y .+ x

julia> @symbolic f([1,2], a)
broadcasted(+, [1, 2], a)

julia> @symbolic f([b,c], a)
broadcasted(+, [b, c], a)

julia> @syms a b::Float64
(a, b)

julia> f(x, y) =  x > 1 ? y+2 : y
f (generic function with 2 methods)

julia> @symbolic f(2, a)
a + 2

julia> f(x::Float64, y) =  float(y)
f (generic function with 2 methods)

julia> @symbolic f(b, a)
float(a)
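The split between symbolic arguments and constants can be illustrated with a much cruder technique than Mjolnir's tracing: plain operator overloading. In this hypothetical sketch, `ToySym` is a stand-in type (not SymbolicUtils' `Sym`) whose arithmetic builds an `Expr`. Concrete arguments are simply evaluated, so `x > 1` on a constant is decided during the call and only the taken branch is recorded. Unlike Mjolnir, this approach cannot get past a comparison on a symbolic value.

```julia
# Toy symbolic value: wraps a Symbol or an Expr built up by overloaded operators.
struct ToySym
    ex::Union{Symbol,Expr}
end

# Addition involving a ToySym records an expression instead of computing a number.
Base.:+(x::ToySym, y::Number) = ToySym(:($(x.ex) + $y))
Base.:+(x::Number, y::ToySym) = ToySym(:($x + $(y.ex)))

# Same function as in the REPL session above.
f(x, y) = x > 1 ? y + 2 : y

f(2, ToySym(:a)).ex  # :(a + 2), the constant 2 decided the branch
f(0, ToySym(:a)).ex  # :a
# f(ToySym(:x), ToySym(:a)) throws a MethodError: no comparison is defined for ToySym
```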

shashi, May 10 '20 19:05

Here is the world age issue:

julia> function make_and_call(expr, args, vals)
           func(IR(expr, args))(vals...)
       end

julia> make_and_call(a+b, [a,b], [1,2])
ERROR: MethodError: no method matching ##261(::Int64, ::Int64)
The applicable method may be too new: running in world age 27244, while current world is 27246.
Closest candidates are:
  ##261(::Any, ::Any) at /home/shashi/.julia/dev/IRTools/src/eval.jl:18 (method too new to be called from this world context.)
Stacktrace:
 [1] make_and_call(::SymbolicUtils.Term{Number}, ::Array{SymbolicUtils.Sym,1}, ::Array{Int64,1}) at ./REPL[91]:2
 [2] top-level scope at REPL[92]:1
 [3] run_backend(::REPL.REPLBackend) at /home/shashi/.julia/packages/Revise/MgvIv/src/Revise.jl:1023
 [4] top-level scope at none:0

lol, I guess converting back to IR is only useful for rewriting functions with a single basic block right now.
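In plain Julia, a common workaround for this error is `Base.invokelatest`, which performs the call in the latest world age at the cost of a dynamic dispatch. A minimal sketch, with `eval` of an anonymous function standing in for `func(IR(...))`; `make_and_call_latest` is an illustrative name, not part of SymbolicUtils or IRTools.

```julia
# Build a one-argument function from an expression at run time, then call it.
function make_and_call_latest(body, arg::Symbol, val)
    g = eval(:($arg -> $body))        # defines a new method in a newer world age
    # Calling g(val) directly here would fail with "method too new to be called
    # from this world context", as in the error above.
    return Base.invokelatest(g, val)  # dispatch in the latest world instead
end

make_and_call_latest(:(x^2 + 1), :x, 3)  # 10
```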

shashi, May 10 '20 19:05

using SymbolicUtils
using ModelingToolkit
using ModelingToolkit: expand_derivatives, to_mtk
using SymbolicUtils: to_symbolic
function D(f, T; simplify=true)
    @syms t()::T
    @derivatives DD'~to_mtk(t())
    expr = @symbolic f(t())
    deriv_expr = to_symbolic(expand_derivatives(DD(to_mtk(expr)), simplify))
    @show deriv_expr
    IR(deriv_expr, [t()])
end
julia> f(x::Float64) = sin(cos(x)) - cos(sin(x))
f (generic function with 1 method)
julia> D(f, Float64, simplify=false)
deriv_expr = (one(cos(sin(t()))) * sin(sin(t())) * cos(t())) + (-1 * one(sin(cos(t()))) * cos(cos(t())) * sin(t()))
1: (%1)
  %2 = (cos)(%1)
  %3 = (sin)(%2)
  %4 = (one)(%3)
  %5 = (cos)(%1)
  %6 = (cos)(%5)
  %7 = (sin)(%1)
  %8 = (-)(%7)
  %9 = (*)(%8, 1)
  %10 = (*)(%6, %9)
  %11 = (*)(%4, %10)
  %12 = (sin)(%1)
  %13 = (cos)(%12)
  %14 = (one)(%13)
  %15 = (-)(%14)
  %16 = (sin)(%1)
  %17 = (sin)(%16)
  %18 = (-)(%17)
  %19 = (cos)(%1)
  %20 = (*)(%19, 1)
  %21 = (*)(%18, %20)
  %22 = (*)(%15, %21)
  %23 = (+)(%11, %22)
  return %23
julia> D(f, Float64, simplify=true)
deriv_expr = (sin(sin(t())) * cos(t())) + (-1 * cos(cos(t())) * sin(t()))
1: (%1)
  %2 = (sin)(%1)
  %3 = (sin)(%2)
  %4 = (cos)(%1)
  %5 = (*)(%3, %4)
  %6 = (cos)(%1)
  %7 = (cos)(%6)
  %8 = (sin)(%1)
  %9 = (*)(-1, %7, %8)
  %10 = (+)(%5, %9)
  return %10

Tracing AD with simplification! The 2nd derivative goes from 78 lines to 20 lines after simplification.
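As a hand check of the simplified first derivative (writing t for the traced argument `t()`), the chain rule gives:

```latex
\frac{d}{dt}\Big[\sin(\cos t) - \cos(\sin t)\Big]
  = \cos(\cos t)\,(-\sin t) - \big(-\sin(\sin t)\big)\cos t
  = \sin(\sin t)\cos t - \cos(\cos t)\sin t
```

This agrees term by term with `(sin(sin(t())) * cos(t())) + (-1 * cos(cos(t())) * sin(t()))`.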

shashi, May 10 '20 21:05

> Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated?

func just calls eval on the expression, so it has all the usual issues that eval has.

I think the right way to solve this is to use a generated function that does a trace based on the input types. It should also add backedges from all traced functions, to avoid the stale-code problems of JuliaLang/julia#265.
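For readers unfamiliar with the mechanism: the body of a `@generated` function runs once per combination of argument types and returns the code to compile for those types, so the result is reached through normal dispatch rather than a runtime `eval`. The toy below only inspects field types; it is not a tracer and does not cover the backedge part of the proposal. `Point` and `sum_fields` are illustrative names.

```julia
struct Point
    a::Int
    b::Float64
end

# Inside the generator, `x` is bound to the argument's type, not its value.
@generated function sum_fields(x)
    ex = :(0)
    for i in 1:fieldcount(x)
        ex = :($ex + getfield(x, $i))  # unrolled once, at code-generation time
    end
    return ex
end

sum_fields(Point(1, 2.5))  # 3.5
sum_fields((1, 2, 3))      # 6, tuples work too
```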

MikeInnes, May 11 '20 13:05

@MikeInnes I just added a fuzzer for Mjolnir, if you're interested in having a look (it only constructs simple functions from SymbolicUtils expressions):

julia> include("test/fuzzlib.jl")
fuzz_test (generic function with 1 method)

julia> for i=1:500; fuzz_test(0, num_spec, mjolnir=true); end
err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#63),), (const(//), const(-91), const(33)), (const(Rational), const(-91), const(33)), (const(Rational{Int64}), const(-91), const(33)), (const(divgcd), const(-91), const(33)), (const(div), const(-91), const(1)), (const(checked_sdiv_int), const(-91), const(1))])
function ()
    -91//33
end

err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#71), Real), (const(//), const(39), const(1)), (const(Rational), const(39), const(1)), (const(Rational{Int64}), const(39), const(1)), (const(divgcd), const(39), const(1)), (const(div), const(39), const(1)), (const(checked_sdiv_int), const(39), const(1))])
function (b,)
    (39//1 - 29//36*im) - b
end
...

It might sometimes hit #30 but I will fix that.

shashi, May 11 '20 15:05