Metatheory.jl

Why does the Fixpoint rewriter not work as expected?

Open · overshiki opened this issue 2 years ago · 0 comments

Hi, I recently tried to combine rewriters to rewrite an expression until it no longer changes. Specifically, I use Fixpoint(Postwalk(PassThrough(r))) to traverse an Expr, hoping the cancel rules will be applied one after another. For comparison, I also implemented a version using an EGraph. I expected both methods to work, but only the EGraph version does. The code is below:


```julia
using Metatheory
using Metatheory: Prewalk, Postwalk, PassThrough, Chain, Fixpoint

# Two kinds of parameter; a Posi term and a Nega term cancel each other.
abstract type Param end

struct Posi <: Param end
struct Nega <: Param end

struct Model
    param::Param
end

# Return 0 when x and y have opposite parameters, otherwise rebuild x + y.
function to_cancel(x::Model, y::Model)
    if x.param isa Posi && y.param isa Nega
        return 0
    end

    if x.param isa Nega && y.param isa Posi
        return 0
    end

    return :($x + $y)
end

# Classic rewriting: apply the rules bottom-up (Postwalk), leave nodes that
# no rule matches unchanged (PassThrough), and repeat until nothing changes (Fixpoint).
function cancel_rewriter()
    r1 = @rule x y x::Model + y::Model => to_cancel(x, y)
    r2 = @rule a 0 + a::Model --> a::Model
    r3 = @rule a a::Model + 0 --> a::Model
    r = Chain([r1, r2, r3])
    r = Fixpoint(Postwalk(PassThrough(r)))
    return r
end

# The same three rules, as a theory for equality saturation.
function egraph_rules()
    v = AbstractRule[]
    t = @theory x y a begin
        x::Model + y::Model => to_cancel(x, y)
        0 + a::Model --> a::Model
        a::Model + 0 --> a::Model
    end
    append!(v, t)
    return v
end

# Saturate an EGraph with the rules, then extract the expression with the smallest AST size.
function egraph_rewriter(circ, v)
    g = EGraph(circ)
    params = SaturationParams(timeout=10, eclasslimit=40000)
    report = saturate!(g, v, params)
    circ = extract!(g, astsize)
    return circ
end

# Overload + so that `expr + term` appends operands to a single `+` call.
import Base.(+)
function (+)(a::Expr, b::Expr)
    circ = :((+)())
    append!(circ.args, a.args[2:end])
    append!(circ.args, b.args[2:end])
    return circ
end

function (+)(a::Expr, b::Model)
    circ = :((+)())
    append!(circ.args, a.args[2:end])
    push!(circ.args, b)
    return circ
end

# Six alternating terms, rewritten with both approaches.
posi = Model(Posi())
nega = Model(Nega())
expr = :((+)()) + posi + nega + posi + nega + posi + nega
r = cancel_rewriter()
@show expr
@show r(expr)

v = egraph_rules()
nexpr = egraph_rewriter(expr, v)
@show nexpr

# Only two terms, rewritten with the Fixpoint rewriter.
println()
expr = :((+)()) + posi + nega
@show expr
@show r(expr)
```

The output is:

```
expr = :(Model(Posi()) + Model(Nega()) + Model(Posi()) + Model(Nega()) + Model(Posi()) + Model(Nega()))
r(expr) = :(Model(Posi()) + Model(Nega()) + Model(Posi()) + Model(Nega()) + Model(Posi()) + Model(Nega()))
nexpr = 0

expr = :(Model(Posi()) + Model(Nega()))
r(expr) = 0
```
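For reference, here is the structure of the six-term `expr`. The custom `(+)(a::Expr, b::Model)` method pushes every operand onto the argument list of one call, so `expr` is a single flat `+` call with six operands, not a chain of binary calls. The snippet below redefines the types so it runs without Metatheory. `flat` and `nested` are illustrative names that do not appear in the original code.

```julia
# Minimal copies of the reproducer's types so this snippet runs on its own.
abstract type Param end
struct Posi <: Param end
struct Nega <: Param end
struct Model
    param::Param
end

posi, nega = Model(Posi()), Model(Nega())

# Mirror what (+)(a::Expr, b::Model) does: start from the empty call
# :((+)()) and push each Model onto the same argument list.
flat = :((+)())
for m in (posi, nega, posi, nega, posi, nega)
    push!(flat.args, m)
end

# The same sum built as nested binary calls: ((posi + nega) + posi) + ...
nested = foldl((a, b) -> Expr(:call, :+, a, b), (posi, nega, posi, nega, posi, nega))

@show flat.args[1]            # the operator symbol
@show length(flat.args) - 1   # operands of the top-level call in the flat form
@show length(nested.args) - 1 # operands of the top-level call in the nested form
```

The flat form has six operands at the top level, while each call in the nested form has exactly two.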

As you can see, cancel_rewriter and egraph_rewriter use the same set of rules. The only difference is that cancel_rewriter uses the Fixpoint(Postwalk(PassThrough(r))) mechanism, while egraph_rewriter uses an EGraph. Moreover, with only two terms (expr = :(Model(Posi()) + Model(Nega()))), cancel_rewriter works as expected.
It is not surprising that egraph_rewriter works; the question is why cancel_rewriter does not. Maybe my understanding of Fixpoint is incorrect?
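To check my understanding of Fixpoint, here is a dependency-free toy model of Fixpoint(Postwalk(PassThrough(r))). The helpers `fixpoint`, `postwalk`, `passthrough`, `cancel_rule` and `isbinaryplus` are hypothetical stand-ins written for this sketch. They are not Metatheory's implementation. The sketch also assumes, without my having verified it against Metatheory's source, that the pattern `x::Model + y::Model` only matches a `+` call with exactly two operands.

```julia
# Hypothetical, dependency-free sketch; these helpers are NOT Metatheory's API.
abstract type Param end
struct Posi <: Param end
struct Nega <: Param end
struct Model
    param::Param
end

# Assumption: the rules only match a `+` call with exactly two operands.
isbinaryplus(x) = x isa Expr && x.head == :call && length(x.args) == 3 && x.args[1] == :+

# Toy version of the three cancel rules; returns `nothing` when no rule applies.
function cancel_rule(x)
    isbinaryplus(x) || return nothing
    a, b = x.args[2], x.args[3]
    if a isa Model && b isa Model && typeof(a.param) != typeof(b.param)
        return 0                 # Posi + Nega or Nega + Posi
    end
    a == 0 && return b           # 0 + a --> a
    b == 0 && return a           # a + 0 --> a
    return nothing
end

# PassThrough-like: keep x unchanged when the rule returns `nothing`.
passthrough(rule) = x -> something(rule(x), x)

# Postwalk-like: rewrite the children first, then the rebuilt node.
postwalk(f, x) = x isa Expr ? f(Expr(x.head, map(a -> postwalk(f, a), x.args)...)) : f(x)

# Fixpoint-like: reapply f until one full pass leaves the expression unchanged.
function fixpoint(f, x)
    while true
        y = f(x)
        isequal(y, x) && return x
        x = y
    end
end

rewrite(x) = fixpoint(y -> postwalk(passthrough(cancel_rule), y), x)

posi, nega = Model(Posi()), Model(Nega())
terms = (posi, nega, posi, nega, posi, nega)
flat = Expr(:call, :+, terms...)                          # one call, six operands
nested = foldl((a, b) -> Expr(:call, :+, a, b), terms)   # chain of binary calls

@show rewrite(flat)
@show rewrite(nested)
@show rewrite(Expr(:call, :+, posi, nega))
```

Under that assumption, the toy reproduces the pattern in the output above. The two-term expression and the nested chain both reduce to 0, but the flat six-operand call comes back unchanged. This happens because no binary rule ever matches the flat call, so the first pass changes nothing and the fixpoint loop stops immediately.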

overshiki · Apr 11 '22 04:04