SymbolicUtils.jl
Common subexpression elimination
Done with @dpsanders:
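```julia
using SymbolicUtils
using SymbolicUtils: Sym, Term
using SymbolicUtils.Rewriters
using DataStructures  # for OrderedDict

# Helper for fresh symbolic variables (not used by `cse` below,
# which names subexpressions with plain `gensym()` Symbols)
newsym() = Sym{Number}(gensym("cse"))

function cse(expr)
    dict = OrderedDict()  # subexpression => name, in order of first appearance
    # Match every Term: reuse its name if already seen, otherwise assign a fresh one
    r = @rule ~x::(x -> x isa Term) => haskey(dict, ~x) ? dict[~x] : dict[~x] = gensym()
    # Post-order walk: children are replaced by their names before the parent is matched
    final = Postwalk(Chain([r]))(expr)
    # `name => subexpression` pairs, followed by the name of the whole expression
    [[var=>ex for (ex, var) in pairs(dict)]..., final]
end
```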
Examples:
```julia
julia> @syms x y;

julia> cse(cos(cos(x)) + sin(cos(x)))
5-element Array{Any,1}:
Symbol("##260") => cos(x)
Symbol("##261") => cos(##260)
Symbol("##262") => sin(##260)
Symbol("##263") => ##261 + ##262
Symbol("##263")

julia> SymbolicUtils.show_simplified[] = false
false

julia> cse(cos(cos(x)) + sin(cos(x)))
5-element Array{Any,1}:
Symbol("##264") => cos(x)
Symbol("##265") => cos(##264)
Symbol("##266") => sin(##264)
Symbol("##267") => ##265 + ##266
Symbol("##267")

julia> cse(cos(cos(x)) + cos(cos(x)))
4-element Array{Any,1}:
Symbol("##268") => cos(x)
Symbol("##269") => cos(##268)
Symbol("##270") => ##269 + ##269
Symbol("##270")
```
This issue is a good place to think about some API questions:
- What should CSE return?
- Should we be able to place the result inside a different Term?
So this is essentially SSA form, right? What if we made a struct `SSATerm` (or `CSETerm` if you prefer) that behaves as if it were a Term but actually stores its contents in this manner.
This would certainly make interop with things like Mjolnir.jl and IRTools.jl easier.
> behaves as if it were a Term
The interface for this is:
- `operation(t)::Function`
- `arguments(t)::Vector`
I can imagine `operation(t)` being a dedicated function, something like `function block end`.
But `arguments(t)` will need to contain assignment operations, and I'm not quite sure how to represent assignments in the worldview of terms. I don't think a term with `=` as its head captures the same meaning as changing an environment.
But it's perfectly possible to convert this thing to a Term when needed.
So I think we can get most of what we need by adding this as a top-level feature separate from terms, with opt-in conversion to a Term. Call it BasicBlock or something. Once we can do conditionals, we can add more such nodes and compose BasicBlocks to form more complex pieces of computation.
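A minimal sketch of what such a BasicBlock could look like, assuming plain Julia `Expr` values stand in for terms. The struct layout and the `substitute`/`to_tree` helpers are hypothetical, not an existing SymbolicUtils API. `to_tree` is the opt-in conversion back to an ordinary expression; because it inlines every name, it gives up the sharing.

```julia
# Hypothetical BasicBlock: a straight-line list of single-assignment statements
# (SSA-style, each name assigned exactly once) followed by a result.
struct BasicBlock
    assignments::Vector{Pair{Symbol,Any}}  # name => expression, in evaluation order
    result::Any                            # value of the block, usually a name
end

# Replace every name defined in `env` by its definition, recursively through `e`.
function substitute(e, env::Dict{Symbol,Any})
    if e isa Symbol
        return get(env, e, e)
    elseif e isa Expr
        return Expr(e.head, (substitute(a, env) for a in e.args)...)
    else
        return e
    end
end

# Opt-in conversion back to an ordinary expression tree. Inlining every name
# duplicates shared subexpressions, i.e. it undoes the CSE.
function to_tree(b::BasicBlock)
    env = Dict{Symbol,Any}()
    for (name, ex) in b.assignments
        env[name] = substitute(ex, env)  # earlier names are already expanded
    end
    return substitute(b.result, env)
end
```

For example, the block `t1 = cos(x); t2 = cos(t1); t3 = sin(t1); t4 = t2 + t3` with result `t4` converts back to `cos(cos(x)) + sin(cos(x))`.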
Hi @shashi, I had modified this to work in my context, but it's no longer working. Could something in the last merge have broken it?
My setup looks like this:
```julia
# `atoms(s)` returns the free variables of `s` (defined elsewhere in my package);
# `RW` is assumed to be an alias for `SymbolicUtils.Rewriters`
function cse(s::Symbolic)
    vars = atoms(s)
    dict = OrderedDict()
    r = @rule ~x => csestep(~x, vars, dict)
    final = RW.Postwalk(RW.PassThrough(r))(s)
    [[var=>ex for (ex, var) in pairs(dict)]...]
end

export csestep

csestep(s::Sym, vars, dict) = s  # never replace a bare symbol
csestep(s, vars, dict) = s       # leave non-symbolic values (e.g. numbers) alone

function csestep(x::S, vars, dict) where {S <: Symbolic}
    # Avoid breaking local variables out of their scope
    isempty(setdiff(atoms(x), vars)) || return x
    if !haskey(dict, x)
        dict[x] = Sym{symtype(x)}(gensym())
    end
    return dict[x]
end
```
Here `atoms` returns the free variables (maybe I should call it that instead, but `atoms` is shorter). I need this so that CSE knows a `Sum` is built from something like a lambda term (index -> value), and the index should never escape the sum.
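The scoping check in `csestep` can be illustrated without SymbolicUtils. In the sketch below, `free_vars` is a hypothetical stand-in for `atoms` that works on Julia `Expr` trees and treats a call `Sum(body, i, lo, hi)` (the shape of the `Sum` printed later in this thread) as binding `i` inside `body`. The equally hypothetical `cse_scoped` only hoists a subexpression when all of its free variables are already in scope.

```julia
# Free variables of a Julia `Expr`, treating the hypothetical binder form
# Sum(body, i, lo, hi) as binding the index `i` inside `body`.
# The function name in call position (e.g. :cos, :+) is not counted.
function free_vars(e)
    e isa Symbol && return Set{Symbol}([e])
    e isa Expr || return Set{Symbol}()     # literals have no free variables
    if e.head === :call
        f, args = e.args[1], e.args[2:end]
        if f === :Sum && length(args) == 4
            body, i, lo, hi = args
            return union(setdiff(free_vars(body), [i]), free_vars(lo), free_vars(hi))
        end
        return union(Set{Symbol}(), map(free_vars, args)...)
    end
    return union(Set{Symbol}(), map(free_vars, e.args)...)
end

# Scope-aware CSE: a call node is hoisted only if every free variable in it
# is already in scope (the outer variables plus names hoisted so far).
function cse_scoped(ex, outer)
    scope = Set{Symbol}(outer)
    dict = Dict{Any,Symbol}()
    assignments = Pair{Symbol,Any}[]
    function walk(e)
        e isa Expr || return e
        e2 = Expr(e.head, map(walk, e.args)...)
        e2.head === :call || return e2
        # Mentions a variable bound further in (e.g. a Sum index): keep it inline
        issubset(free_vars(e2), scope) || return e2
        get!(dict, e2) do
            name = gensym("cse")
            push!(assignments, name => e2)
            push!(scope, name)
            name
        end
    end
    return assignments, walk(ex)
end
```

For `Sum(y[i] - a * cos(b), i, 1, n) + cos(b)` with outer variables `y, a, b, n`, `cos(b)` is named once and shared between the summand and the outside, and the whole `Sum` is hoisted, but `y[i] - ...` stays inside the sum because it mentions the index `i`.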
This used to work, but now it stops when it gets to sin. Any idea what's going on?
I'm not sure; try `Chain` instead of `PassThrough`.
Thanks, I fixed it a while back:
```julia
julia> cse(cos(cos(x)) + sin(cos(x)))
4-element Vector{Pair{Symbol, SymbolicUtils.Symbolic{Number}}}:
Symbol("##473") => cos(x)
Symbol("##474") => cos(var"##473")
Symbol("##475") => sin(var"##473")
Symbol("##476") => var"##474" + var"##475"

julia> cse(cos(cos(x)) + cos(cos(x)))
3-element Vector{Pair{Symbol, SymbolicUtils.Symbolic{Number}}}:
Symbol("##477") => cos(x)
Symbol("##478") => cos(var"##477")
Symbol("##479") => 2var"##478"
```
Since I also have an implementation of this in ReversePropagation.jl, maybe we should pool our efforts and add this to Symbolics.jl?
Nice! I need to have another look at that package; I didn't realize you had CSE set up in it.
I'd love to have something like this for general-purpose use, as long as there's a way to represent free variables. For me this comes up in the symbolic representation of sums, where sharing the index variable outside the sum doesn't really make sense. Ideally SymbolicUtils would have symbolic summations built in; my current implementation works but feels a little hacky.
@cscherrer Can you give an example of what you mean by symbolic summations?
Say you have this model:
```julia
julia> m = @model x begin
           a ~ Normal()
           b ~ Normal()
           y ~ For(1:100000) do j
               Normal(μ = a + b*x[j])
           end
           return y
       end;
```
and some fake data:
```julia
julia> x = randn(100000);

julia> y = rand(m(x=x));
```
Then the log-posterior density (up to an additive constant) is
```
Sum(-0.5(((getindex(y, var"##i1#739")) - a - (b*(getindex(x, getindex(UnitRange(1, 100000), var"##i1#739")))))^2), var"##i1#739", 1, 100000) - (0.5(a^2)) - (0.5(b^2))
```
We really don't want to expand this sum, since we'd end up with a ridiculous number of terms (one per observation). But we can still apply some rewrite rules and then do some constant folding, so we end up with
```
-0.5(121986.42933250747 + (10139.425041642466a) + (100001(a^2)) + (100587.87406561377(b^2)) + (262.72135043120505a*b) - (94008.24079371039b))
```
which is then really fast to sample from.
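The folding works because the summand is quadratic in `a` and `b`. Expanding the square gives Σⱼ(yⱼ − a − b·xⱼ)² = Σy² − 2a·Σy − 2b·Σxy + n·a² + 2ab·Σx + b²·Σx², so the 100,000-term sum collapses to five precomputed statistics. Adding the prior terms `0.5(a^2)` and `0.5(b^2)` inside the `-0.5(...)` turns the `a²` coefficient into n + 1, matching the `100001(a^2)` above. The snippet below only checks this identity numerically on made-up data with a smaller n; it is an illustration, not the code that produced the output above, and `logpost_direct`/`logpost_folded` are hypothetical names.

```julia
using Random

# Made-up data, for illustration only
n = 1000
rng = MersenneTwister(1)
x = randn(rng, n)
y = 0.5 .+ 2.0 .* x .+ randn(rng, n)

# Direct evaluation: O(n) work for every (a, b)
logpost_direct(a, b) = -0.5 * sum((y .- a .- b .* x) .^ 2) - 0.5a^2 - 0.5b^2

# Folded form: statistics computed once, then O(1) work for every (a, b)
Sy, Sx, Syy, Sxx, Sxy = sum(y), sum(x), sum(abs2, y), sum(abs2, x), sum(x .* y)
logpost_folded(a, b) =
    -0.5 * (Syy - 2a*Sy - 2b*Sxy + (n + 1)*a^2 + 2a*b*Sx + (Sxx + 1)*b^2)
```

Both functions agree up to floating-point rounding for any `a` and `b`; only the folded one stays cheap as n grows.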
Thanks, that's interesting. I haven't thought about that situation at all. I've just been trying to get simple scalar expressions to work!
> add this to Symbolics.jl?

Naah! It should be here in SymbolicUtils.jl. See #200.
Do you have an approach for tracking variable scope, e.g. in Lambda or Func expressions? Otherwise CSE has lots of corner cases.
The one in #200 is very conservative: it does not go inside a Func yet, and it outputs a self-contained Let.
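To make "very conservative" concrete, here is a sketch written against plain Julia `Expr` trees rather than SymbolicUtils terms. It is not the implementation in #200, and `cse_let` is a hypothetical name. It refuses to look inside anonymous functions (`->`), so nothing bound by a lambda can be hoisted out of its scope, and it wraps the shared subexpressions in a self-contained `let` block that can be evaluated on its own.

```julia
# Conservative CSE (hypothetical sketch): never descends into `->` lambdas,
# and returns a self-contained `let ... end` expression.
function cse_let(ex)
    dict = Dict{Any,Symbol}()
    body = Expr[]                          # `name = subexpr` statements, in order
    function walk(e)
        e isa Expr || return e
        e.head === Symbol("->") && return e  # leave lambda bodies untouched
        e2 = Expr(e.head, map(walk, e.args)...)
        e2.head === :call || return e2
        get!(dict, e2) do
            name = gensym("cse")
            push!(body, :($name = $e2))
            name
        end
    end
    result = walk(ex)
    # `let` with no bindings: Expr(:let, Expr(:block), body)
    return Expr(:let, Expr(:block), Expr(:block, body..., result))
end
```

Skipping lambdas entirely gives up some sharing (a `cos(x)` inside a lambda body is never reused), but it sidesteps the scope corner cases raised above.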