MethodOfLines.jl

Move to using SymbolicUtils chains

Open xtalax opened this issue 1 year ago • 0 comments

Describe the bug 🐞

We aren't catching derivatives nested within more complex equation structures.

Expected behavior

These should just work. They don't because we do rudimentary, custom recursive term splitting instead of using the excellent tools SymbolicUtils already provides for this.

See the comments in this rule generator:

"""
    generate_winding_rules(II, s, depvars, derivweights, bcmap, indexmap, terms; skip = [])

Generate replacement pairs that discretize odd-order derivatives of the dependent
variables at grid index `II` using `upwind_difference`.

The result is `safe_vcat` of two groups of pairs:

1. `t => rewritten` for every `t in terms` matched by a `@rule`. One rule catches an
   odd-order derivative `(Differential(x)^d)(u)` inside a product `*(a..., D, b...)`.
   Another catches such a product divided by `c`. The remaining factors
   (`*(a..., b...)`, divided by `c` in the second case) are passed to
   `upwind_difference` as its first argument.
2. `(Differential(x)^d)(u) => upwind_difference(...)` for every bare odd-order
   derivative of each `u in depvars` with respect to each `x in ivs(u, s)`.

Derivative orders listed in `skip` are excluded from both groups.

NOTE(review): rules are only tried against each whole `t in terms`, so derivatives
nested deeper inside an expression are not caught. The proposed fix is to apply the
rules through `SymbolicUtils.Prewalk(SymbolicUtils.Chain(rules))` and drop `terms`.
"""
@inline function generate_winding_rules(II::CartesianIndex, s::DiscreteSpace, depvars, derivweights::DifferentialDiscretizer, bcmap, indexmap, terms; skip = [])
    wind_ufunc(v, I, x) = s.discvars[v][I]

    # Odd derivative orders configured for `x`, with any order in `skip` removed.
    # Shared by every group below so they always agree on which orders apply.
    odd_orders(x) = let orders = derivweights.orders[x]
        setdiff(orders[isodd.(orders)], skip)
    end

    # Catch multiplication: *(a..., D(u), b...)
    mul_rules = reduce(safe_vcat, [reduce(safe_vcat, [[@rule *(~~a, $(Differential(x)^d)(u), ~~b) => upwind_difference(
        *(~a..., ~b...), d, Idx(II, s, u, indexmap), s,
        filter_interfaces(bcmap[operation(u)][x]), depvars, derivweights,
        (x2i(s, u, x), x), u, wind_ufunc, indexmap) for d in odd_orders(x)] for x in ivs(u, s)], init = []) for u in depvars], init = [])

    # Catch division of such a product: /(*(a..., D(u), b...), c) — see issue #1
    div_rules = reduce(safe_vcat, [reduce(safe_vcat, [[@rule /(*(~~a, $(Differential(x)^d)(u), ~~b), ~c) => upwind_difference(
        *(~a..., ~b...) / ~c, d, Idx(II, s, u, indexmap), s,
        filter_interfaces(bcmap[operation(u)][x]), depvars, derivweights,
        (x2i(s, u, x), x), u, wind_ufunc, indexmap) for d in odd_orders(x)] for x in ivs(u, s)], init = []) for u in depvars], init = [])

    rules = safe_vcat(mul_rules, div_rules)

    # `term => rewritten` for every (term, rule) match, in terms-major order.
    # The accumulator must exist before the loop so `push!` has a target.
    wind_rules = []
    for t in terms, r in rules
        rewritten = r(t)  # apply each rule once; `nothing` means "no match"
        rewritten === nothing || push!(wind_rules, t => rewritten)
    end

    # Fallback pairs for bare odd-order derivatives (Differential(x)^d)(u).
    # TODO: express these as `@rule`s too, so every rule can live in one rewriter chain.
    bare_pairs = vec(mapreduce(safe_vcat, depvars, init = []) do u
        mapreduce(safe_vcat, ivs(u, s), init = []) do x
            j = x2i(s, u, x)
            oddorders = odd_orders(x)
            if isempty(oddorders)
                []
            else
                map(oddorders) do d
                    (Differential(x)^d)(u) => upwind_difference(d, Idx(II, s, u, indexmap), s,
                        filter_interfaces(bcmap[operation(u)][x]), derivweights, (j, x), u,
                        wind_ufunc, true)
                end
            end
        end
    end)

    return safe_vcat(wind_rules, bare_pairs)
end

What we want is to take all the derivative and integral rules in their current order, including the `boundaryvalmaps`, and apply them through a derivative rewriter chain instead:

- Wrap the rules as `deriv_rewriter = SymbolicUtils.Prewalk(SymbolicUtils.Chain(deriv_rules))`.
- Call this rewriter on the equation's lhs and rhs.
- Only afterwards, substitute the varmaps as usual.

This applies the rules recursively, top to bottom, left to right.

We can then get rid of `split_terms` and the `terms` argument to the rule generators.

— xtalax, Jan 26 '24 17:01