FunctionTerm is dead, long live FunctionTerm
This is a pretty substantial change in how non-special function calls are
represented. Instead of generating an anonymous function, the new
FunctionTerm simply wraps the called function, its arguments (wrapped as
Terms), and the original expression. When evaluated with modelcols, it uses
Base.Broadcast.broadcasted to lazily fuse nested function calls and then at
the top level calls materialize to run. I hope that this will provide
performance that's comparable to the anonymous function, while being both more
run-time friendly and much simpler.
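A minimal sketch of the lazy-fusion idea. It uses hypothetical toy types (ToyColumn, ToyFunctionTerm) and helpers (lazy_eval, evaluate_term) rather than the StatsModels types. The only real APIs it relies on are Base.Broadcast.broadcasted and materialize. Nested calls become unevaluated Broadcasted objects, and only the outermost call is materialized, so log(a + b) runs as a single fused loop.

```julia
using Base.Broadcast: broadcasted, materialize

# Hypothetical toy types for illustration; NOT the StatsModels types.
struct ToyColumn
    name::Symbol
end

struct ToyFunctionTerm{F,Args<:Tuple}
    f::F            # the called function, e.g. log
    args::Args      # its arguments, themselves wrapped as (toy) terms
    exorig::Expr    # the original expression, kept for printing
end

# Leaves: look up a column in the data, or pass a literal through.
lazy_eval(t::ToyColumn, data) = data[t.name]
lazy_eval(t::Number, data) = t
# A call becomes an unevaluated Broadcasted object, so nested calls fuse.
lazy_eval(t::ToyFunctionTerm, data) =
    broadcasted(t.f, map(a -> lazy_eval(a, data), t.args)...)

# Only the top-level call is materialized, running one fused loop.
evaluate_term(t::ToyFunctionTerm, data) = materialize(lazy_eval(t, data))

data = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
inner = ToyFunctionTerm(+, (ToyColumn(:a), ToyColumn(:b)), :(a + b))
outer = ToyFunctionTerm(log, (inner,), :(log(a + b)))
result = evaluate_term(outer, data)   # same values as log.(data.a .+ data.b)
```

Keeping the original expression around costs nothing at evaluation time, but it gives printing and error messages something readable to show.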
As a side effect, all of the macro-time parsing has been removed, and all the
special syntax now applies at run time (via method overloading of +, &, ~,
and *; I know, I'm sorry, this is just for continuity's sake right now).
Also, ~~shamelessly stealing from~~ inspired by @oxinabox's suggestions for how to
handle nested special and non-special functions (#117), I've added some
additional special syntax: protect and unprotect. protect says to treat
every call below it as non-special, while unprotect (in an otherwise protected
context) says to treat everything below it as potentially special (e.g., as if
it occurred at the top level of a formula). This isn't perfect at the moment,
but it goes a long way towards being able to do something like 1 - unprotect(poly(x, 3))
and having it do something sensible (actually that might work; what doesn't work
is (1 - unprotect(a*b)), since that will generate a + b + a&b, which doesn't
get fused into a single matrix, and that messes with the broadcasting).
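To make the protect/unprotect semantics concrete, here is a toy interpreter over plain Julia expressions. It is a hypothetical sketch: ToyCall and interpret are made-up names, only + is treated as special, and the PR itself applies these rules through methods at run time rather than by walking Expr trees. An unprotected + concatenates terms. Any other call becomes one opaque term whose arguments are protected by default. protect(...) turns special handling off below it, and unprotect(...) turns it back on.

```julia
# Hypothetical toy: a non-special call kept as one opaque term.
struct ToyCall
    f::Symbol
    args::Vector{Any}
end
Base.:(==)(x::ToyCall, y::ToyCall) = x.f == y.f && x.args == y.args

# Return the list of terms for `ex`; `protected` disables special syntax.
function interpret(ex, protected::Bool = false)
    if !(ex isa Expr && ex.head == :call)
        return Any[ex]                        # a variable or literal
    end
    f, args = ex.args[1], ex.args[2:end]
    if f == :protect
        return interpret(only(args), true)    # nothing below is special
    elseif f == :unprotect
        return interpret(only(args), false)   # as if at the top level
    elseif f == :+ && !protected
        # special +: concatenate the terms from every argument
        return reduce(vcat, [interpret(a, false) for a in args])
    else
        # non-special call: its arguments are protected by default
        inner = [interpret(a, true) for a in args]
        return Any[ToyCall(f, Any[length(t) == 1 ? t[1] : t for t in inner])]
    end
end

interpret(:(a + log(b + c)))
# == Any[:a, ToyCall(:log, Any[ToyCall(:+, Any[:b, :c])])]
interpret(:(protect(a + b)))
# == Any[ToyCall(:+, Any[:a, :b])]
interpret(:(log(unprotect(b + c))))
# == Any[ToyCall(:log, Any[Any[:b, :c]])], i.e. log receives two separate terms
```

The last case loosely mirrors the 1 - unprotect(a*b) problem: the unprotected sum comes back as several separate terms rather than one fused argument.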
On a whim, I also added an @unprotect op macro, which will designate a function as
special in unprotected contexts. This simply adds a method like
apply_schema(t::FunctionTerm{typeof($op)}, sch::Schema, Mod::Type) =
apply_schema($op(t.args...), sch, Mod)
So, if you have something like FunctionTerm{typeof(+)} in a non-protected
context, then it will be converted into a call to +(args[1], args[2], ...)
when you do apply_schema. This is taking advantage of the fact that all the
special handling of special syntax is done via method overloading of the
corresponding functions, so it should work fine at runtime. At the moment it's
basically only useful for internal purposes but a two-argument form (function
and context type) might be useful for packages.
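A hypothetical sketch of what the @unprotect op macro does. The names UVar, UFun, toy_apply, and @toy_unprotect are made up and stand in for the real term types, apply_schema, and @unprotect. The Schema and model-type arguments of the real apply_schema are left out. Because special syntax is just a method for +, re-dispatching a captured call as op(args...) is enough to re-apply the special rule.

```julia
# Hypothetical toy term types; NOT the StatsModels types.
abstract type AbstractUTerm end
struct UVar <: AbstractUTerm
    name::Symbol
end
struct UFun{F} <: AbstractUTerm
    f::F          # the captured function, e.g. + or log
    args::Tuple   # its (toy) term arguments
end

# Special syntax as a plain method: + on terms concatenates them.
Base.:+(a::AbstractUTerm, b::AbstractUTerm) = (a, b)

# Default (toy) "apply schema" step: leave terms unchanged.
toy_apply(t::AbstractUTerm) = t

# Toy analogue of @unprotect: for a captured call to `op`, re-run the call
# on the processed arguments so that the special method for `op` applies.
macro toy_unprotect(op)
    esc(:(toy_apply(t::UFun{typeof($op)}) = $op(map(toy_apply, t.args)...)))
end

@toy_unprotect(+)

a, b = UVar(:a), UVar(:b)
captured_sum = UFun(+, (a, b))    # a + b captured as a function call
captured_log = UFun(log, (a,))    # log(a) is not special
toy_apply(captured_sum)           # (a, b): the special + method ran
toy_apply(captured_log)           # unchanged
```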
Tests do pass with this change, but it probably needs more tests to specifically cover the protect/unprotect stuff. Docs aren't updated because that seems premature at this point before folks have had a chance to weigh in.
The major problem with this proposal, now that I've messed with it a bit, is that these kinds of methods really hit the compiler hard:
# + concatenates terms
Base.:+(terms::AbstractTerm...) = (unique(reduce(+, terms))..., )
Base.:+(a::AbstractTerm) = a
Base.:+(a::AbstractTerm, b::AbstractTerm) = (a,b)
# associative rule for +
Base.:+(as::TupleTerm, b::AbstractTerm) = (as..., b)
Base.:+(a::AbstractTerm, bs::TupleTerm) = (a, bs...)
Base.:+(as::TupleTerm, bs::TupleTerm) = (as..., bs...)
(TupleTerm is just NTuple{N, AbstractTerm} where {N}).
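A self-contained toy version of the concatenation rules above, with hypothetical names (AbstractToyTerm, ToyTerm, ToyTupleTerm). Two details differ from the listing, and both are assumptions of this sketch. The alias is written NTuple{N, AbstractToyTerm}, because NTuple takes the length first. The unary and n-ary methods are left out, because Base's generic +(a, b, c, xs...) already folds longer sums through the binary methods. The comments on the last lines point out why this is hard on the compiler: each number of terms gives a new concrete tuple type.

```julia
# Hypothetical toy term types; NOT the StatsModels types.
abstract type AbstractToyTerm end
struct ToyTerm <: AbstractToyTerm
    name::Symbol
end
# NTuple takes the length first: NTuple{N, T} is Tuple{T, ..., T} (N times).
const ToyTupleTerm = NTuple{N, AbstractToyTerm} where {N}

# + concatenates terms; the associative rules flatten nested tuples.
Base.:+(a::AbstractToyTerm, b::AbstractToyTerm) = (a, b)
Base.:+(as::ToyTupleTerm, b::AbstractToyTerm) = (as..., b)
Base.:+(a::AbstractToyTerm, bs::ToyTupleTerm) = (a, bs...)
Base.:+(as::ToyTupleTerm, bs::ToyTupleTerm) = (as..., bs...)

a, b, c, d = ToyTerm(:a), ToyTerm(:b), ToyTerm(:c), ToyTerm(:d)
two = a + b                 # (a, b)
three = a + b + c           # Base's fallback computes (a + b) + c
four = (a + b) + (c + d)    # (a, b, c, d)
# typeof(two) is Tuple{ToyTerm, ToyTerm}, while typeof(three) has three
# ToyTerm parameters, so code specialized on these types compiles per count.
```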
The reason is that a new method has to be compiled for every combination of number and types of terms being added together. This is related to #165: all the type parameters in the term types create a lot of compiler overhead. I can think of a few ways around this:
- we could go back to doing these transformations at parse time, which would require that FunctionTerm continue to hold onto two copies of the args (one parsed and the other not). It would also require some duplication of functionality if we want to continue to support run-time construction as a first-class interface, since we would still need all these methods for the same syntax to apply as inside the macro.
- we could use some other non-parametrized version of the +-ed terms, like a Vector, or a custom wrapper. This seems potentially the least disruptive (and could even convert back to a tuple if needed for performance at some point).
- we could represent ALL calls, even to special syntax like +, &, and *, as FunctionTerms, and only parse them at apply_schema time, or do some fanciness during construction. I think this is the most radical/involved and I'm not sure it's really worth it.
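The second option, a non-parametric container, might look something like this sketch. VTerm, VTermSum, and vterms are hypothetical names. The result of + has the same concrete type however many terms are summed. This does not, by itself, stop an n-ary call from being specialized on its number of arguments.

```julia
# Hypothetical toy term types; NOT the StatsModels types.
abstract type AbstractVTerm end
struct VTerm <: AbstractVTerm
    name::Symbol
end
# Hypothetical non-parametric container for a sum of terms.
struct VTermSum <: AbstractVTerm
    terms::Vector{AbstractVTerm}
end

# The list of terms that a single term or a sum contributes.
vterms(t::AbstractVTerm) = AbstractVTerm[t]
vterms(s::VTermSum) = s.terms

# + concatenates and de-duplicates; the result type never depends on length.
Base.:+(a::AbstractVTerm, b::AbstractVTerm) =
    VTermSum(unique!(vcat(vterms(a), vterms(b))))

a, b, c = VTerm(:a), VTerm(:b), VTerm(:c)
s2 = a + b             # VTermSum holding [a, b]
s4 = a + b + c + a     # VTermSum holding [a, b, c]; the repeated a is dropped
```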
Edit: I did try wrapping all these in a @nospecialize block, but it didn't seem to have any effect. Then again, I don't really know what I'm doing, so...
Now I'm really confused. I've actually gone through the exercise of replacing the tuple-based representation for +ed terms with Vector-based storage (in this branch), and I STILL get much longer first-call timings for every new number of terms, which suggests that some new methods are getting compiled, even though all the types are the same (they no longer depend on the number of terms).
Base.:+(terms::AbstractTerm...) is still going to be specialized on the number of arguments, right?
Yeah, I'm afraid so. Makes me wonder how stuff like this is handled in Base...
Edit: it's easy enough to find out :)
https://github.com/JuliaLang/julia/blob/24f033c9517cd186448acc3bededa1a64eaad09f/base/operators.jl#L521-L543
So I think defining that method is actually not necessary, except that we also want to call unique on the output...
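A sketch of the point about Base. With only binary methods defined, a + b + a already works through Base's generic +(a, b, c, xs...) fallback (in base/operators.jl), so no n-ary method is needed. One possible way to keep the unique behaviour without an n-ary method is to skip duplicates inside the binary methods themselves. That is an assumption of this sketch, not what the PR does. AbstractDTerm, DTerm, and DTupleTerm are hypothetical names.

```julia
# Hypothetical toy term types; NOT the StatsModels types.
abstract type AbstractDTerm end
struct DTerm <: AbstractDTerm
    name::Symbol
end
const DTupleTerm = NTuple{N, AbstractDTerm} where {N}

# Only binary methods are defined; each one skips terms already present.
Base.:+(a::AbstractDTerm, b::AbstractDTerm) = a == b ? (a,) : (a, b)
Base.:+(as::DTupleTerm, b::AbstractDTerm) = b in as ? as : (as..., b)
Base.:+(a::AbstractDTerm, bs::DTupleTerm) = a in bs ? bs : (a, bs...)
Base.:+(as::DTupleTerm, bs::DTupleTerm) = foldl(+, bs; init = as)

a, b = DTerm(:a), DTerm(:b)
a + b + a          # +(a, b, a) goes through Base's fallback and gives (a, b)
(a + b) + (b + a)  # (a, b)
```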
Specialisation around splatting is weird in general, and it's not documented what it will do.
IIRC foo(x::Vararg) and foo(x...) specialize differently.
But in any case, if you don't want specialization, just use the @nospecialize macro.
This is done either for a section of code via:
@nospecialize
foo(x)=1
bar(x)=2
@specialize
or for an argument:
foobar(x, @nospecialize(y)) = x + y
If @nospecialize isn't working, then that would be a bug in Base, and you should open an issue rather than redo the whole design to work around the fact that you can't turn off specialization.
I have had @nospecialize work before.
I'm more and more concerned that just putting a few @nospecialize around the base methods we're overloading isn't going to be enough, since every time you hit a constructor for any term that can have children (formula, interaction, matrix, function) you're going to trigger compilation for every number and type of children, and same thing for every other method involving those types. At that point you're going to have to put @nospecialize on just about everything that can take an AbstractTerm which seems like a sign that something has gone wrong.
At this point, my plan is to do a bit of benchmarking to see whether formula creation gets appreciably WORSE with this PR. If not, we can leave fixing the performance for another PR and focus on whether these changes do enough justice to @oxinabox's vision to merge.
Okay I've done a bit of timing (using the script from https://gist.github.com/kleinschmidt/f51d305d56a590030c4f8688cbf18929). On master, these are the timings (first and second run times, since I'm mostly interested in compilation time itself here):
y ~ 1 + x
0.000033 seconds (24 allocations: 1.453 KiB)
0.000032 seconds (24 allocations: 1.453 KiB)
y ~ a + b
0.008989 seconds (3.57 k allocations: 222.652 KiB)
0.000014 seconds (11 allocations: 720 bytes)
y ~ a + b + c
0.009176 seconds (3.58 k allocations: 223.059 KiB)
0.000017 seconds (12 allocations: 752 bytes)
y ~ a + b + c + d
0.011082 seconds (3.58 k allocations: 222.840 KiB)
0.000017 seconds (13 allocations: 784 bytes)
y ~ a + b + c + d + e
0.009841 seconds (3.58 k allocations: 223.012 KiB)
0.000023 seconds (15 allocations: 896 bytes)
y ~ a + b + c + d + e + f
0.010660 seconds (3.58 k allocations: 223.105 KiB)
0.000029 seconds (16 allocations: 928 bytes)
y ~ log(a)
0.065474 seconds (96.09 k allocations: 5.480 MiB)
0.029364 seconds (9.05 k allocations: 520.255 KiB)
y ~ log(a) + log(b)
0.941035 seconds (1.19 M allocations: 59.954 MiB, 10.87% gc time)
0.746309 seconds (1.02 M allocations: 51.098 MiB, 12.56% gc time)
y ~ log(a) + log(b) + log(c)
0.793438 seconds (1.43 M allocations: 72.068 MiB, 2.55% gc time)
0.647390 seconds (901.92 k allocations: 45.123 MiB, 1.24% gc time)
y ~ log(a) + log(b) + log(c) + log(d)
0.777564 seconds (1.01 M allocations: 50.078 MiB, 1.10% gc time)
0.809190 seconds (1.01 M allocations: 50.081 MiB, 1.07% gc time)
y ~ log(a) + log(b) + log(c) + log(d) + log(e)
0.914824 seconds (1.11 M allocations: 55.022 MiB, 1.15% gc time)
0.918323 seconds (1.11 M allocations: 55.028 MiB, 1.02% gc time)
y ~ log(a) + log(b) + log(c) + log(d) + log(e) + log(f)
1.036401 seconds (1.21 M allocations: 59.980 MiB, 1.71% gc time)
1.017207 seconds (1.21 M allocations: 59.978 MiB, 0.86% gc time)
y ~ exp(a)
0.027674 seconds (9.39 k allocations: 547.880 KiB)
0.030770 seconds (9.05 k allocations: 520.177 KiB)
y ~ exp(a) + exp(b)
0.778712 seconds (1.20 M allocations: 59.997 MiB, 2.29% gc time)
0.684971 seconds (1.02 M allocations: 51.120 MiB, 1.34% gc time)
y ~ exp(a) + exp(b) + exp(c)
0.742065 seconds (1.24 M allocations: 62.374 MiB, 1.24% gc time)
0.655097 seconds (901.92 k allocations: 45.129 MiB, 1.56% gc time)
y ~ exp(a) + exp(b) + exp(c) + exp(d)
0.773946 seconds (1.01 M allocations: 50.081 MiB, 1.32% gc time)
0.778134 seconds (1.01 M allocations: 50.084 MiB, 2.41% gc time)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e)
0.891833 seconds (1.11 M allocations: 55.043 MiB, 1.04% gc time)
0.873386 seconds (1.11 M allocations: 55.022 MiB, 1.05% gc time)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e) + exp(f)
0.982716 seconds (1.21 M allocations: 59.970 MiB, 1.11% gc time)
1.017105 seconds (1.21 M allocations: 59.977 MiB, 1.86% gc time)
And on this PR (commit eaec084):
y ~ 1 + x
0.000100 seconds (25 allocations: 1.406 KiB)
0.000093 seconds (25 allocations: 1.406 KiB)
y ~ a + b
0.007148 seconds (2.17 k allocations: 135.030 KiB)
0.000016 seconds (11 allocations: 768 bytes)
y ~ a + b + c
0.042739 seconds (53.78 k allocations: 3.048 MiB)
0.000026 seconds (23 allocations: 1.484 KiB)
y ~ a + b + c + d
0.036280 seconds (53.91 k allocations: 3.052 MiB)
0.000020 seconds (25 allocations: 1.547 KiB)
y ~ a + b + c + d + e
0.047632 seconds (54.05 k allocations: 3.059 MiB)
0.000022 seconds (28 allocations: 1.688 KiB)
y ~ a + b + c + d + e + f
0.039089 seconds (54.18 k allocations: 3.064 MiB)
0.000020 seconds (30 allocations: 1.750 KiB)
y ~ log(a)
0.027760 seconds (44.74 k allocations: 2.699 MiB)
0.000014 seconds (11 allocations: 480 bytes)
y ~ log(a) + log(b)
0.277548 seconds (332.92 k allocations: 17.035 MiB)
0.000028 seconds (27 allocations: 1.703 KiB)
y ~ log(a) + log(b) + log(c)
0.080824 seconds (121.33 k allocations: 6.716 MiB)
0.000031 seconds (42 allocations: 2.906 KiB)
y ~ log(a) + log(b) + log(c) + log(d)
0.113681 seconds (114.74 k allocations: 6.305 MiB, 10.12% gc time)
0.000034 seconds (49 allocations: 3.297 KiB)
y ~ log(a) + log(b) + log(c) + log(d) + log(e)
0.075690 seconds (118.20 k allocations: 6.460 MiB)
0.000034 seconds (57 allocations: 3.828 KiB)
y ~ log(a) + log(b) + log(c) + log(d) + log(e) + log(f)
0.075736 seconds (121.95 k allocations: 6.627 MiB)
0.000037 seconds (64 allocations: 4.219 KiB)
y ~ exp(a)
0.013628 seconds (6.69 k allocations: 411.996 KiB)
0.000013 seconds (11 allocations: 480 bytes)
y ~ exp(a) + exp(b)
0.251576 seconds (332.96 k allocations: 17.036 MiB)
0.000035 seconds (27 allocations: 1.703 KiB)
y ~ exp(a) + exp(b) + exp(c)
0.083678 seconds (121.34 k allocations: 6.716 MiB)
0.000037 seconds (42 allocations: 2.906 KiB)
y ~ exp(a) + exp(b) + exp(c) + exp(d)
0.103966 seconds (114.75 k allocations: 6.304 MiB, 9.53% gc time)
0.000035 seconds (49 allocations: 3.297 KiB)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e)
0.087459 seconds (118.21 k allocations: 6.461 MiB)
0.000039 seconds (57 allocations: 3.828 KiB)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e) + exp(f)
0.080457 seconds (121.96 k allocations: 6.629 MiB)
0.000038 seconds (64 allocations: 4.219 KiB)
Bottom line: this PR is MUCH faster for anything involving a custom function, on both first and second run (noticeably so, going from ~1s every time you create a formula with a function call in it to ~100ms on first run and <1ms after that). For formulas without function calls it's slower on first run (~50ms vs. ~10ms) but comparable after that.
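For a rough sense of scale, the ratios behind that summary can be recomputed from the timings listed above. The function-call case uses y ~ log(a) + log(b), and the plain case uses y ~ a + b + c. This is only arithmetic on the reported numbers. Second-run times in the tens of microseconds are noisy, so treat that ratio as an order of magnitude only.

```julia
# Timings in seconds, copied from the listings above: (first run, second run)
master_log2 = (0.941035, 0.746309)   # y ~ log(a) + log(b) on master
pr_log2     = (0.277548, 0.000028)   # y ~ log(a) + log(b) on this PR
master_abc  = 0.009176               # y ~ a + b + c, first run on master
pr_abc      = 0.042739               # y ~ a + b + c, first run on this PR

speedup_first  = master_log2[1] / pr_log2[1]   # about 3.4x faster
speedup_second = master_log2[2] / pr_log2[2]   # roughly 27,000x faster
slowdown_plain = pr_abc / master_abc           # about 4.7x slower
```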
There was, I remember, a good reason why it made sense to move the parsing rules
out of the macro and into run-time but at the moment I can't recall exactly
why. Only that it had something to do with the protect/unprotect stuff. I
don't think it was just that I had started to dislike keeping both the
"parsed" and "non-parsed" copies of the arguments in the function term, but that
is a definite bonus in my view.
I reverted the addition of test/Project.toml, since that was the only thing breaking 1.0 compatibility, and I'm now using Compat to get only() instead of duplicating the code.
I'm having a hard time trying to master the changes in this PR. :-)
Do you know why the new code is faster than the old one despite the cost of compilation? Maybe the compiler is smart enough to avoid compiling everything?
There are two things going on. Before, every FunctionTerm had a type parameter that was the type of the generated anonymous function. A new anonymous function was generated every time, even if the call was the same. So @formula(y ~ exp(a)) generates a FunctionTerm{typeof(exp), typeof(SomeAnon)}, and calling it again generates a FunctionTerm{typeof(exp), typeof(AnotherAnon)}. That's what's causing (I think) the slow times on master: any method with any argument that involves a function term is recompiled for every instance.
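A small illustration of the first point. OldFunctionTerm and NewFunctionTerm are hypothetical structs standing in for the two designs. Every anonymous-function expression has its own type. So when the anonymous function is a type parameter, two terms built from the same call have different concrete types. A term that stores only the called function does not have this problem.

```julia
# Hypothetical stand-ins for the two designs; NOT the StatsModels types.
struct OldFunctionTerm{F,G}
    f::F       # the called function, e.g. exp
    anon::G    # a generated anonymous function (old design)
end
struct NewFunctionTerm{F}
    f::F       # only the called function (new design)
end

# Two textually identical anonymous functions still have distinct types...
old1 = OldFunctionTerm(exp, x -> exp(x))
old2 = OldFunctionTerm(exp, x -> exp(x))
# ...so typeof(old1) != typeof(old2), and any method specialized on them
# is compiled separately for each instance.
new1, new2 = NewFunctionTerm(exp), NewFunctionTerm(exp)
# typeof(new1) == typeof(new2) == NewFunctionTerm{typeof(exp)}
```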
This PR fixes that, by removing the generated anonymous function altogether. So each instance of exp is going to have the same type. However, it also introduces the use of additional methods for +, &, and * operating on Terms to implement the algebraic rules that used to be handled at macro time. This is necessary because of how the protect/unprotect logic (which is one of the main reasons for these changes to function terms) is implemented, as a function call. See these lines: https://github.com/JuliaStats/StatsModels.jl/pull/183/files#diff-54c988da53bc06e26c2bc99378565e55R257-R289 The gist is that when you hit a FunctionTerm{typeof(unprotect)} during apply_schema, that means that you want to re-start applying the special parser rules (associative/distributive rules) below that point in the tree. One way to do that is to implement those rules as methods for &, +, and * (that's these lines: https://github.com/JuliaStats/StatsModels.jl/pull/183/files#diff-13e0a130102f1c75064347b419d09b5fR375-R404), then when you hit a FunctionTerm{typeof(+)} etc. in an unprotected context, you can just do term.f(term.args...), which is also (now) exactly what the @formula macro generates.
Beyond the advantages for simplifying the representation of function terms (by eliminating the need to keep protected and unprotected versions of the arguments), this also has the benefit of making the DSL rules simpler (replacing 100-200 lines of code for the expr-based macro-time syntax rules with ~10-20 LOC for the methods) and the formula macro itself really minimal (all it does is wrap symbols as terms and wrap non-special calls in function terms, which is also a much simpler operation now; this is it now: https://github.com/JuliaStats/StatsModels.jl/pull/183/files#diff-b4d7500de9148b52dc4251c7f5c66082R76-R84).
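A toy version of such a minimal macro body, to show how little it needs to do. toy_capture, CapturedTerm, and CapturedCall are hypothetical names. The set of special operators is an assumption, and the real macro in the PR may differ in detail. Symbols become terms. Special calls outside a protected context are left as ordinary calls, so the run-time methods apply the algebra. Any other call becomes a captured call whose arguments are protected.

```julia
# Hypothetical toy term types; NOT the StatsModels types.
struct CapturedTerm
    name::Symbol
end
struct CapturedCall{F,A<:Tuple}
    f::F          # the called function
    args::A       # captured arguments
    exorig::Any   # the original expression
end

const TOY_SPECIAL = (:+, :&, :*, :~)   # assumed set of special operators

function toy_capture(ex, protected::Bool = false)
    if ex isa Symbol
        return :(CapturedTerm($(QuoteNode(ex))))
    elseif ex isa Expr && ex.head == :call
        f, args = ex.args[1], ex.args[2:end]
        if !protected && f in TOY_SPECIAL
            # special syntax: keep the call so run-time methods handle it
            return Expr(:call, f, map(a -> toy_capture(a, false), args)...)
        else
            # non-special call: everything below it is protected
            captured = map(a -> toy_capture(a, true), args)
            return :(CapturedCall($f, ($(captured...),), $(QuoteNode(ex))))
        end
    else
        return ex   # literals such as 1 pass through unchanged
    end
end

# log is captured, and the + inside it is captured too (it is protected).
t = eval(toy_capture(:(log(b + c))))
```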
The downside is that there are more run-time calls which, given the current "everything is a type parameter" design, can be a bit rough on the compiler, since y ~ a + b has a different concrete type than y ~ 1 + a or y ~ a + b + c or y ~ a + b + a&b. Not to mention that +(a, b, c) has a different signature from +(a, b, a&b), despite all three arguments being abstract terms. I think these problems are fixable and there are good reasons to fix them anyway, but they do slow down formula construction a bit. That was the point of the benchmarking above. But in any case, I doubt that's the bottleneck in most applications, and it's a minimal price to pay given the massive speedups we get whenever any function term is involved (as @matthieugomez reported in #164), at least on second run.
Ah indeed, it's logical that not putting an anonymous function in the type parameter improves things.
Another example encountered in the wild that is MASSIVELY improved by this PR:
@formula(y ~ 1 + A * B * (C + C^2 + D) + (1 + A + B | G))
the issue is that before this PR, the * expansion is handled before anything else, so you end up with a ton of terms with C^2, and THEN capture_call is invoked, generating a DIFFERENT anonymous function for every individual term, which really hammers the compiler.
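A rough count of why that formula was so expensive before. It assumes the usual rule that l * r expands to l + r + l&r for every pair of terms, and it ignores the (1 + A + B | G) part. Interactions are modelled as sets of variable names, which is a hypothetical toy representation, with C2 standing in for C^2.

```julia
# Toy representation: an interaction term is the set of variables in it.
const Interaction = Set{Symbol}

# l * r expands to l + r + l&r for every pair (l, r).
star(ls::Vector{Interaction}, rs::Vector{Interaction}) =
    vcat(ls, rs, [union(l, r) for l in ls for r in rs])

# A sum of main effects, e.g. sumof(:C, :C2, :D) for C + C^2 + D.
sumof(names::Symbol...) = [Interaction([n]) for n in names]

expanded = star(star(sumof(:A), sumof(:B)), sumof(:C, :C2, :D))
n_terms = length(expanded)                   # 15
n_with_c2 = count(t -> :C2 in t, expanded)   # 4
```

Under this counting, four separate terms mention C^2, and on master each of them would have gone through capture_call on its own and received a different anonymous function.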
This is ready for review again @nalimilan
I think the failing project level codecov is due to deleting a lot of well-covered code so can be ignored (or rather is better fixed in a separate PR which adds more tests) since the patch coverage is high.