StatsModels.jl

Question: Is it possible to write `y ~ 1 + poly(x,deg=5)`?

Status: Open · behinger opened this issue 1 year ago · 2 comments

Hi! I hope I didn't miss this somewhere in the docs: is it possible to extend the formula syntax with keyword–value pairs? For example, `poly(x, deg=5)` instead of the boring `poly(x, 5)`.

Cheers, Benedikt

— behinger, Aug 11 '22 12:08

It's not currently possible, but the following patch

diff --git a/src/formula.jl b/src/formula.jl
index 6cab6e1..6d2ef57 100644
--- a/src/formula.jl
+++ b/src/formula.jl
@@ -188,13 +188,19 @@ function parse!(ex::Expr, rewrites::Vector)
 
     # parse a copy of non-special calls
     ex_parsed = ex.args[1] ∉ SPECIALS ? deepcopy(ex) : ex
-    
+
     # iterate over children, checking for special rules
     child_idx = 2
     while child_idx <= length(ex_parsed.args)
-        @debug "  ($(ex_parsed.args[1])) i=$child_idx: $(ex_parsed.args[child_idx])"
+        child = ex_parsed.args[child_idx]
+        @debug "  ($(ex_parsed.args[1])) i=$child_idx: $child"
+        if Meta.isexpr(child, :parameters) || Meta.isexpr(child, :kw)
+            @debug "  not descending into keywords"
+            child_idx += 1
+            continue
+        end
         # depth first: parse each child first
-        parse!(ex_parsed.args[child_idx], rewrites)
+        parse!(child, rewrites)
         # find first rewrite rule that applies
         rule = filterfirst(r->applies(ex_parsed, child_idx, r), rewrites)
         # re-write according to that rule and update the child to position rewrite nex_parsedt
@@ -228,7 +234,7 @@ function capture_call_ex!(ex::Expr, ex_parsed::Expr)
                f_anon_ex,
                tuple(symbols...),
                Meta.quot(deepcopy(ex)),
-               :[$(ex_parsed.args[2:end]...)]]
+               :[$(map(e -> e isa Expr ? Meta.quot(e) : e, ex_parsed.args[2:end])...)]]
     return ex
 end
 
@@ -244,6 +250,12 @@ function terms!(ex::Expr)
     elseif is_call(ex, :capture_call)
         # final argument of capture_call holds parsed terms
         ex.args[end].args .= terms!.(ex.args[end].args)
+    elseif Meta.isexpr(ex, :parameters)
+        # `:parameters` has no function head in args[1]: every arg is a keyword
+        ex.args .= terms!.(ex.args)
+    elseif Meta.isexpr(ex, :kw)
+        # args[1] is the keyword name; only the value can contain terms
+        ex.args[2] = terms!(ex.args[2])
     end
     return ex
 end
diff --git a/src/terms.jl b/src/terms.jl
index b49a902..4b5df99 100644
--- a/src/terms.jl
+++ b/src/terms.jl
@@ -327,8 +327,21 @@ capture_call(args...) = FunctionTerm(args...)
 
 extract_symbols(x) = Symbol[]
 extract_symbols(x::Symbol) = [x]
-extract_symbols(ex::Expr) =
-    is_call(ex) ? mapreduce(extract_symbols, union, ex.args[2:end]) : Symbol[]
+# Collect the data symbols referenced by `ex`. Keyword names are not data:
+# for `f(x; k=a)` only `x` and `a` are collected, never `k`.
+function extract_symbols(ex::Expr)
+    if is_call(ex)
+        # skip args[1] (the function name); `init` keeps `f()` from throwing
+        mapreduce(extract_symbols, union, ex.args[2:end]; init=Symbol[])
+    elseif Meta.isexpr(ex, :parameters)  # `f(; x)` or `f(; x=a, y=b)`
+        # no function head here: every arg is a `:kw` node or a shorthand symbol
+        mapreduce(extract_symbols, union, ex.args; init=Symbol[])
+    elseif Meta.isexpr(ex, :kw)  # `f(x=a)` or `f(; x=a)`: the value only
+        extract_symbols(last(ex.args))
+    else
+        Symbol[]
+    end
+end
 
 ################################################################################
 # showing terms

will get you part of the way there if you want to have a go at implementing it.

— ararslan, Aug 11 '22 17:08