SymbolicUtils.jl
Use Mjolnir for tracing
This is just a prototype for the time being (Mjolnir is not registered yet), but it shows off the basic ideas. This is the right solution to #16 (handling dispatch correctly), #8 (handling typeof etc. correctly) and #67 (handling ifelse and if/else).
julia> using SymbolicUtils
julia> @symbolic 3x^2 + 2y + 1
1 + (3 * (x ^ 2)) + (2 * y)
julia> f(x::Real) = abs2(Complex(x, 2x))
f (generic function with 1 method)
julia> @symbolic f(x)+y
(5 * (x ^ 2)) + y
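A minimal sketch of the dispatch problem behind #16, assuming a plain operator-overloading tracer rather than Mjolnir: the hypothetical wrapper type `Rec` records the arithmetic applied to it. That works for an untyped function, but a method like `f(x::Real)` above never accepts the wrapper, because `Rec` is not a `Real` and dispatch rejects the call before any recording happens. Per the PR, tracing avoids this by working from the variable's type instead of passing a wrapper value.

```julia
# Hypothetical recording type for illustration; not SymbolicUtils' or Mjolnir's API.
struct Rec
    ex::Any   # expression recorded so far: a Symbol or an Expr
end

# Each overload returns a new Rec holding a larger expression.
Base.:+(a::Rec, b::Number) = Rec(:($(a.ex) + $b))
Base.:*(a::Number, b::Rec) = Rec(:($a * $(b.ex)))
Base.:^(a::Rec, n::Integer) = Rec(:($(a.ex) ^ $n))

g_untyped(x) = 3x^2 + 1      # untyped: accepts a Rec, so it can be recorded
g_real(x::Real) = 3x^2 + 1   # typed: Rec is not <: Real, so this never dispatches

x = Rec(:x)
traced = g_untyped(x).ex     # the recorded expression 3 * x ^ 2 + 1

rejected = try
    g_real(x)
    false
catch err
    err isa MethodError      # no method g_real(::Rec)
end
```

Making `Rec <: Real` would fix this one case, but then any method restricted to some other abstract type has the same problem, which is why tracing by type is attractive.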
Demo of #8:
julia> _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
_qreltype (generic function with 1 method)
julia> @symbolic zero(_qreltype(typeof(x)))
0.0
ifelse works:
julia> relu(x::Real) = ifelse(x == 0, x, zero(x))
relu (generic function with 2 methods)
julia> @symbolic relu(x)
ifelse(x == 0, x, 0.0)
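Why `ifelse` is the easier half of #67 can be shown with plain overloading, assuming a hypothetical wrapper `SymVal` and a hypothetical helper `sym_select`: a select-style call treats its condition as data, so a symbolic condition can simply be recorded, whereas `if` or `?:` needs an actual `Bool` at run time, so a symbolic condition throws a `TypeError`. A tracer that sees the function's IR can observe the branch itself, which is why the PR lists if/else alongside ifelse.

```julia
# Hypothetical symbolic wrapper, for illustration only.
struct SymVal
    ex::Any
end

# A comparison on a symbolic value can only return another symbolic value.
Base.:(>)(a::SymVal, b::Number) = SymVal(:($(a.ex) > $b))

# Hypothetical branch-free select: with a Bool it picks a branch, with a
# symbolic condition it records an `ifelse` call instead of choosing.
sym_select(c::Bool, a, b) = c ? a : b
sym_select(c::SymVal, a, b) = SymVal(Expr(:call, :ifelse, c.ex, unwrap(a), unwrap(b)))
unwrap(v::SymVal) = v.ex
unwrap(v) = v

relu_branch(x) = x > 0 ? x : zero(x)      # control flow: needs a real Bool
relu_select(x) = sym_select(x > 0, x, 0)  # data flow only: can be recorded

x = SymVal(:x)
recorded = relu_select(x).ex              # ifelse(x > 0, x, 0)

err = try
    relu_branch(x)
catch e
    e   # `?:` on a SymVal condition raises a TypeError (non-boolean used in boolean context)
end
```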
I forgot to mention that there's a syntax for variable types:
julia> @symbolic relu(x::Int64)
ifelse(x == 0, x, 0)
Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir.
Codecov Report
Merging #78 into master will decrease coverage by 2.46%. The diff coverage is 0.00%.
@@ Coverage Diff @@
## master #78 +/- ##
==========================================
- Coverage 89.81% 87.34% -2.47%
==========================================
Files 8 9 +1
Lines 432 490 +58
==========================================
+ Hits 388 428 +40
- Misses 44 62 +18
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/SymbolicUtils.jl | 100.00% <ø> (ø) | |
| src/trace.jl | 0.00% <0.00%> (ø) | |
| src/rulesets.jl | 50.00% <0.00%> (ø) | |
| src/matchers.jl | 90.09% <0.00%> (+0.30%) | ↑ |
| src/rule_dsl.jl | 98.87% <0.00%> (+1.54%) | ↑ |
| src/simplify.jl | 95.08% <0.00%> (+2.90%) | ↑ |
| src/types.jl | 89.58% <0.00%> (+2.91%) | ↑ |
| src/methods.jl | 96.29% <0.00%> (+7.40%) | ↑ |
| src/util.jl | 76.19% <0.00%> (+9.52%) | ↑ |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 5a8f2f9...fd1e2af.
Absolutely fantastic!! I'm excited.
It might be good to add a Symutils -> IR conversion given a vector of arguments.
Amazing, Mike! This is a really promising direction.
Glad you like it!
> It might be good to add a Symutils -> IR conversion given a vector of arguments.
Yeah, that seems useful – especially for ultimately turning a Term into something you can evaluate.
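One way to turn a recorded expression into something callable, sketched with plain Julia `Expr` objects rather than SymbolicUtils' `Term` or the IR conversion discussed here: wrap the expression in an anonymous function whose parameters are the given argument symbols, then `eval` it. `to_function` is a hypothetical name, not an existing API.

```julia
# Hypothetical helper: build `(args...) -> body` from an expression and an
# argument list, then evaluate it into a callable anonymous function.
function to_function(body, args::Vector{Symbol})
    lambda = Expr(:->, Expr(:tuple, args...), Expr(:block, body))
    return eval(lambda)   # defines the function in the enclosing module
end

fn = to_function(:(a + 2b), [:a, :b])
fn(2, 3)   # 8; called at top level, after `eval` has returned
```

Calling the result from inside the same function that ran `eval` is a different story, because of world age.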
> Mjolnir is also happy tracing through array/linalg code; I'm not sure what this package's support for arrays looks like, but if there's an example I'd be happy to demo something via Mjolnir.
This is something we need to start adding methods for; right now I think an array of symbols is all that works. We could use something like Mjolnir's shaped arrays, with the shape information stored as the T in Symbolic{T}.
Wow, fantastic. Thanks, Mike!
julia> using SymbolicUtils
julia> @syms a b c
(a, b, c)
julia> ir=IR(a+b+c, [a,b]; mod=Main)
1: (%1, %2)
%3 = (+)(%1, %2, Main.c)
return %3
julia> h=func(ir)
##260 (generic function with 1 method)
julia> h(2,3)
5 + c
Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated?
I made the macro use values from the current environment: if symbols are passed, it uses their symtype to trace; everything else is treated as a Const.
This lets you call functions with any mix of arguments:
julia> f(y, x) = y .+ x
julia> @symbolic f([1,2], a)
broadcasted(+, [1, 2], a)
julia> @symbolic f([b,c], a)
broadcasted(+, [b, c], a)
julia> @syms a b::Float64
(a, b)
julia> f(x, y) = x > 1 ? y+2 : y
f (generic function with 2 methods)
julia> @symbolic f(2, a)
a + 2
julia> f(x::Float64, y) = float(y)
f (generic function with 2 methods)
julia> @symbolic f(b, a)
float(a)
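The `broadcasted(+, [1, 2], a)` results above follow from how Julia lowers dot syntax: `y .+ x` becomes `Base.Broadcast.materialize(Base.Broadcast.broadcasted(+, y, x))`, where `broadcasted` builds a lazy `Broadcasted` object and `materialize` computes it. A small plain-Julia illustration with ordinary numbers (no symbolic values involved):

```julia
using Base.Broadcast: broadcasted, materialize, Broadcasted

y = [1, 2]
lazy = broadcasted(+, y, 10)   # records the operation; nothing is computed yet
eager = materialize(lazy)      # performs the computation: [11, 12]

# `y .+ 10` lowers to the same two steps; `Meta.@lower y .+ 10` prints the lowered form.
```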
World age:
julia> function make_and_call(expr, args, vals)
func(IR(expr, args))(vals...)
end
julia> make_and_call(a+b, [a,b], [1,2])
ERROR: MethodError: no method matching ##261(::Int64, ::Int64)
The applicable method may be too new: running in world age 27244, while current world is 27246.
Closest candidates are:
##261(::Any, ::Any) at /home/shashi/.julia/dev/IRTools/src/eval.jl:18 (method too new to be called from this world context.)
Stacktrace:
[1] make_and_call(::SymbolicUtils.Term{Number}, ::Array{SymbolicUtils.Sym,1}, ::Array{Int64,1}) at ./REPL[91]:2
[2] top-level scope at REPL[92]:1
[3] run_backend(::REPL.REPLBackend) at /home/shashi/.julia/packages/Revise/MgvIv/src/Revise.jl:1023
[4] top-level scope at none:0
lol, I guess conversion back to IR is only useful for rewriting functions with n=1 basic blocks right now.
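The error above is the usual world-age rule: a method created by `eval` while a function is running is newer than that function's world, so calling it directly fails. A minimal plain-Julia sketch (hypothetical `make_and_call_direct` and `make_and_call_latest`, using `Expr` instead of SymbolicUtils terms and `IR`) reproduces the failure and shows the standard workaround, `Base.invokelatest`, which performs the call in the newest world at the cost of a dynamic dispatch.

```julia
# Build `(args...) -> body` at run time and call it within the same function.
function make_and_call_direct(body, args::Vector{Symbol}, vals)
    f = eval(Expr(:->, Expr(:tuple, args...), Expr(:block, body)))
    return f(vals...)                      # fails: f's method is newer than this call's world
end

function make_and_call_latest(body, args::Vector{Symbol}, vals)
    f = eval(Expr(:->, Expr(:tuple, args...), Expr(:block, body)))
    return Base.invokelatest(f, vals...)   # dispatches in the latest world
end

result = make_and_call_latest(:(a + b), [:a, :b], [1, 2])   # 3

failure = try
    make_and_call_direct(:(a + b), [:a, :b], [1, 2])
catch e
    e   # MethodError: method too new to be called from this world context
end
```

`invokelatest` is fine for one-off calls; for code that is called repeatedly, the generated-function approach suggested later in the thread avoids paying for `eval` and dynamic dispatch each time.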
using SymbolicUtils
using ModelingToolkit
using ModelingToolkit: expand_derivatives, to_mtk
using SymbolicUtils: to_symbolic
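function D(f, T; simplify=true)
    @syms t()::T                     # t() is a symbolic term with symtype T
    @derivatives DD'~to_mtk(t())     # differential operator DD with respect to t()
    expr = @symbolic f(t())          # trace f with t() as its argument
    # convert to ModelingToolkit, expand the derivative (optionally simplifying), convert back
    deriv_expr = to_symbolic(expand_derivatives(DD(to_mtk(expr)), simplify))
    @show deriv_expr
    IR(deriv_expr, [t()])            # turn the result back into IR that takes t() as input
end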
julia> f(x::Float64) = sin(cos(x)) - cos(sin(x))
f (generic function with 1 method)
julia> D(f, Float64, simplify=false)
deriv_expr = (one(cos(sin(t()))) * sin(sin(t())) * cos(t())) + (-1 * one(sin(cos(t()))) * cos(cos(t())) * sin(t()))
1: (%1)
%2 = (cos)(%1)
%3 = (sin)(%2)
%4 = (one)(%3)
%5 = (cos)(%1)
%6 = (cos)(%5)
%7 = (sin)(%1)
%8 = (-)(%7)
%9 = (*)(%8, 1)
%10 = (*)(%6, %9)
%11 = (*)(%4, %10)
%12 = (sin)(%1)
%13 = (cos)(%12)
%14 = (one)(%13)
%15 = (-)(%14)
%16 = (sin)(%1)
%17 = (sin)(%16)
%18 = (-)(%17)
%19 = (cos)(%1)
%20 = (*)(%19, 1)
%21 = (*)(%18, %20)
%22 = (*)(%15, %21)
%23 = (+)(%11, %22)
return %23
julia> D(f, Float64, simplify=true)
deriv_expr = (sin(sin(t())) * cos(t())) + (-1 * cos(cos(t())) * sin(t()))
1: (%1)
%2 = (sin)(%1)
%3 = (sin)(%2)
%4 = (cos)(%1)
%5 = (*)(%3, %4)
%6 = (cos)(%1)
%7 = (cos)(%6)
%8 = (sin)(%1)
%9 = (*)(-1, %7, %8)
%10 = (+)(%5, %9)
return %10
Tracing AD with simplification! The 2nd derivative goes from 78 lines to 20 lines after simplify.
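To see why simplification pays off, here is a tiny stand-in for the ModelingToolkit pipeline in `D`, under the assumption that a hand-written `Expr` differentiator is a fair analogue: `deriv` applies the sum, product and chain rules mechanically, which leaves `* 1` factors much like the unsimplified output above, and `simplify_ex` strips them. Both names are hypothetical; this is not how `expand_derivatives` is implemented.

```julia
# Hypothetical symbolic differentiator over Julia Expr trees.
# Supports +, unary/binary -, binary *, sin and cos.
deriv(ex::Number, v::Symbol) = 0
deriv(ex::Symbol, v::Symbol) = ex == v ? 1 : 0
function deriv(ex::Expr, v::Symbol)
    op, args = ex.args[1], ex.args[2:end]
    if op == :+
        return Expr(:call, :+, (deriv(a, v) for a in args)...)
    elseif op == :- && length(args) == 1
        return :(-$(deriv(args[1], v)))
    elseif op == :-
        return :($(deriv(args[1], v)) - $(deriv(args[2], v)))
    elseif op == :* && length(args) == 2
        a, b = args
        return :($(deriv(a, v)) * $b + $a * $(deriv(b, v)))   # product rule
    elseif op == :sin
        return :(cos($(args[1])) * $(deriv(args[1], v)))      # chain rule
    elseif op == :cos
        return :(-sin($(args[1])) * $(deriv(args[1], v)))     # chain rule
    else
        error("unsupported operation: $op")
    end
end

# Tiny simplifier: drop `* 1` and `+ 0`, and collapse any product containing 0.
simplify_ex(ex) = ex
function simplify_ex(ex::Expr)
    ex.head == :call || return ex
    op = ex.args[1]
    args = Any[simplify_ex(a) for a in ex.args[2:end]]
    if op == :*
        any(a -> a == 0, args) && return 0
        args = filter(a -> a != 1, args)
    elseif op == :+
        args = filter(a -> a != 0, args)
    else
        return Expr(:call, op, args...)
    end
    isempty(args) && return op == :* ? 1 : 0
    length(args) == 1 && return args[1]
    return Expr(:call, op, args...)
end

f_ex = :(sin(cos(x)) - cos(sin(x)))
raw = deriv(f_ex, :x)            # contains `* 1` factors from d(x)/dx
simplified = simplify_ex(raw)
dfdx = eval(:(x -> $simplified)) # compile the simplified derivative
```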
> Nice! Does this sort of thing have world age issues? Like can ModelingToolkit use this function instead of going through GeneralizedGenerated?
func just calls eval on the expression, so it has all the usual issues that eval does.
I think the right way to solve this is to use a generated function that does a trace based on the input types; it should also add backedges from all traced functions to solve #265-style issues (JuliaLang/julia#265: callers not being recompiled when a method they depend on is redefined).
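A minimal sketch of the "generated function that traces on input types" shape, with a deliberately simple stand-in for tracing: a `@generated` function receives only the argument types and returns the code to compile for that signature, so nothing is `eval`'d at call time and no world-age workaround is needed. `unrolled_sum` is hypothetical, and this sketch does not attempt the backedge part, which needs lower-level machinery not shown here.

```julia
# Hypothetical example: the generator runs once per argument *type* signature
# and returns the expression to compile for it.
@generated function unrolled_sum(t::Tuple)
    n = fieldcount(t)          # inside the generator, `t` is the tuple type, e.g. Tuple{Int, Float64}
    n == 0 && return :(0)
    ex = :(t[1])
    for i in 2:n
        ex = :($ex + t[$i])    # build t[1] + t[2] + ... + t[n]
    end
    return ex                  # compiled as the method body for this signature
end

unrolled_sum((1, 2.5, 3))   # 6.5
```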
@MikeInnes I just added a fuzzer for Mjolnir if you're interested in having a look (only constructs simple functions from SymbolicUtils exprs):
julia> include("test/fuzzlib.jl")
fuzz_test (generic function with 1 method)
julia> for i=1:500; fuzz_test(0, num_spec, mjolnir=true); end
err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#63),), (const(//), const(-91), const(33)), (const(Rational), const(-91), const(33)), (const(Rational{Int64}), const(-91), const(33)), (const(divgcd), const(-91), const(33)), (const(div), const(-91), const(1)), (const(checked_sdiv_int), const(-91), const(1))])
function ()
-91//33
end
err = Mjolnir.TraceError(ErrorException("No IR for Tuple{Core.IntrinsicFunction,Int64,Int64}"), Any[(const(#71), Real), (const(//), const(39), const(1)), (const(Rational), const(39), const(1)), (const(Rational{Int64}), const(39), const(1)), (const(divgcd), const(39), const(1)), (const(div), const(39), const(1)), (const(checked_sdiv_int), const(39), const(1))])
function (b,)
(39//1 - 29//36*im) - b
end
...
It might sometimes hit #30, but I will fix that.
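The fuzzing idea can be sketched in plain Julia under a couple of stated assumptions: expressions are restricted to `+`, `-` and `*` over integers (so this toy never reaches the `//` and `checked_sdiv_int` path that trips Mjolnir above), and the reference result comes from a tiny interpreter rather than from SymbolicUtils. All function names are hypothetical; this is not the PR's `test/fuzzlib.jl`.

```julia
using Random

# Generate a random arithmetic expression in one variable `x`.
function rand_expr(rng, depth)
    if depth == 0 || rand(rng) < 0.3
        return rand(rng, Bool) ? :x : rand(rng, -5:5)   # leaf: variable or constant
    end
    op = rand(rng, [:+, :-, :*])
    return Expr(:call, op, rand_expr(rng, depth - 1), rand_expr(rng, depth - 1))
end

# Reference semantics: a tree-walking interpreter.
interp_ex(ex::Integer, x) = ex
interp_ex(ex::Symbol, x) = x            # the only variable is :x
function interp_ex(ex::Expr, x)
    op = ex.args[1]
    a = interp_ex(ex.args[2], x)
    b = interp_ex(ex.args[3], x)
    return op == :+ ? a + b : op == :- ? a - b : a * b
end

# Compile each random expression with eval and compare against the interpreter.
function fuzz_compare(n; seed = 1)
    rng = MersenneTwister(seed)
    failures = Any[]
    for _ in 1:n
        ex = rand_expr(rng, 4)
        f = eval(:(x -> $ex))
        # invokelatest: f was defined after this function started running
        if any(v -> Base.invokelatest(f, v) != interp_ex(ex, v), -3:3)
            push!(failures, ex)
        end
    end
    return failures
end
```

Swapping the `eval` step for a traced-and-rebuilt version of each function is what turns this into a test of the tracer.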