Metatheory.jl icon indicating copy to clipboard operation
Metatheory.jl copied to clipboard

Cannot infer term type in egraph

Open vitrun opened this issue 2 years ago • 1 comments

Something goes wrong when matching against rules built purely from literals, such as pi + 3 --> cos(4) or im + (pi + 3) --> sin(4). The following demo reproduces the issue. I've tried several versions, including master and v1.3.3.

using Metatheory
using Metatheory.EGraphs
using TermInterface

"""
    Term{T}

Minimal symbolic term: an operation `f` together with its arguments `args`.

The type parameter `T` does not appear in any field; it is carried purely as a tag on
the term's type.
"""
struct Term{T}
  f
  args::Vector{Any}
end

"""
    Term(f, args)

Build a `Term{T}` for the operation `f` applied to `args`, deriving `T` from the arguments.

`T` is `Any` when `args` is empty. Otherwise it is the `promote_type` of the `symtype`
of every argument, so arguments beyond the second also contribute to the result.
"""
function Term(f, args)
  # Reducing an empty collection with `promote_type` has no meaningful result here,
  # so the empty case is handled explicitly and keeps the `Any` tag.
  T = isempty(args) ? Any : mapreduce(symtype, promote_type, args)
  return Term{T}(f, args)
end

# Make `promote_type(Irrational{:π}, Int64)` (in this argument order) yield `Real`.
# NOTE(review): this adds a method to a `Base` function for argument types that are
# also owned by `Base` (type piracy), so it affects every caller in the session.
function Base.promote_type(::Type{Irrational{:π}}, ::Type{Int64})
  return Real
end


# TermInterface methods for `Term`.

# Every `Term` reports the expression head `:call`.
function TermInterface.exprhead(e::Term)
  return :call
end

# The operation is the stored `f` field.
function TermInterface.operation(e::Term)
  return e.f
end

# The arguments are the stored `args` field.
function TermInterface.arguments(e::Term)
  return e.args
end

# A `Term` is always reported as a tree, regardless of how many arguments it has.
function TermInterface.istree(e::Term)
  return true
end

# The symbolic type of a `Term{T}` is its type parameter `T`.
function TermInterface.symtype(::Term{T}) where {T}
  return T
end

# For any other value, the symbolic type is the value's own type.
function TermInterface.symtype(::T) where {T}
  return T
end

# Build a new `Term` from `head` and `args`. The template term `x` is used only for
# dispatch, and the `metadata` / `exprhead` keywords are accepted but not used.
TermInterface.similarterm(x::Term, head, args; metadata = nothing, exprhead = :call) =
  Term(head, args)

# Rebuild a `Term` from `op` and `args` when the requested type is some `Term{S}`.
# The requested parameter `S` is not forwarded: `Term(op, args)` derives the type
# parameter from the arguments again. `metadata` and `exprhead` are accepted but unused.
function EGraphs.egraph_reconstruct_expression(
  T::Type{Term{S}}, op, args;
  metadata = nothing, exprhead = nothing,
) where {S}
  return Term(op, args)
end

# Theory with a single active rule whose left- and right-hand sides consist only of
# literals and function calls on literals. The declared variables `a`, `b`, `c` are not
# used by any rule in this block.
pt = @theory a b c  begin
  im + (pi + 3) --> sin(4)
  # Alternative literal-only rule (disabled):
  # pi + 3 --> cos(4)
end


# Build the expression im + (pi + 3) out of `Term`s and create an e-graph from it.
ex = Term(+, [im, Term(+, [pi, 3])])
g = EGraph(ex)

# Pass `Term{symtype(ex)}` to `settermtype!`; presumably this sets the e-graph's default
# term type — confirm against the Metatheory documentation.
settermtype!(g, Term{symtype(ex)})
# Disabled variant that additionally passes `:+` and `2` (presumably operation and arity):
# settermtype!(g, :+, 2, Term{Real})

# Run saturation with the theory `pt`, extract with `astsize`, and print the result.
saturate!(g, pt)
r = extract!(g, astsize)
println(r)

I dug into the code and found the following function in ematch.jl, which I believe is to blame.

"""
    lookup_pat(g::EGraph, p::PatTerm)

Look up the ground pattern term `p` in the e-graph `g`.

Each argument pattern is looked up recursively. If every argument lookup yields an
`EClassId`, an `ENodeTerm{T}` is built from the pattern's expression head, its operation
and those ids, and the result of `lookup(g, node)` is returned. Otherwise `nothing` is
returned.

`T` is obtained from `gettermtype(g, op, ar)`, i.e. it depends only on the operation and
the arity of the pattern.
"""
function lookup_pat(g::EGraph, p::PatTerm)
  @assert isground(p)

  eh = exprhead(p)
  op = operation(p)
  args = arguments(p)
  ar = arity(p)

  T = gettermtype(g, op, ar)

  ids = [lookup_pat(g, pp) for pp in args]
  # Bail out as soon as any argument has no e-class id.
  all(i -> i isa EClassId, ids) || return nothing

  # NOTE(review): a `Symbol` operation is resolved with `eval`, i.e. in the global scope
  # of the module this function is defined in — confirm this is the intended lookup scope.
  n = ENodeTerm{T}(eh, op isa Symbol ? eval(op) : op, ids)
  return lookup(g, n)
end

In the demo above, an overloaded promote_type is used to decide the type parameter of Term{T}, and the egraph knows nothing about it. Moreover, is T = gettermtype(g, op, ar) sufficient to decide the type of an enode term? I doubt it. Consider pi + 3 and im + 3: they have the same op (+) and the same arity (2), but the resulting term types are completely different.

vitrun avatar May 26 '22 15:05 vitrun

Hi @0x0f0f0f , any idea about this issue?

vitrun avatar May 30 '22 02:05 vitrun