
BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

Open • MilesCranmer opened this issue 1 year ago • 18 comments

These new experimental Expression types store both the operators and the variable names within the object itself, unlike the plain Node, which stores only the raw tree, with operators referenced by index into the operator enum.
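For context, here is a minimal sketch of the difference, assuming the DynamicExpressions.jl keyword constructors (double-check the exact API against the DynamicExpressions docs for your version):

using DynamicExpressions

operators = OperatorEnum(; binary_operators=(+, *), unary_operators=(cos,))
x1 = Node{Float64}(; feature=1)
tree = cos(x1) * x1            # a plain `Node`: just the tree, no metadata attached

X = randn(Float64, 1, 5)
tree(X, operators)             # evaluating a `Node` means passing the operators separately

# an `Expression` bundles the operators and variable names together with the tree:
ex = Expression(tree; operators=operators, variable_names=["x1"])
ex(X)                          # nothing extra needs to be threaded through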

This also adds ParametricExpression, for learning basis expressions whose constants can vary depending on the class:

using SymbolicRegression
using Random: MersenneTwister
using Zygote
using MLJBase: machine, fit!, predict

rng = MersenneTwister(0)
X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5)))
X = (; X..., classes=rand(rng, 1:2, 30))
p1 = rand(rng, Float32, 2)
p2 = rand(rng, Float32, 2)

y = [
    2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for
    i in eachindex(X.classes)
]

model = SRRegressor(;
    niterations=10,
    binary_operators=[+, *, /, -],
    unary_operators=[cos, exp],
    populations=10,
    expression_type=ParametricExpression,  # Subtype of `AbstractExpression`
    expression_options=(; max_parameters=2),
    autodiff_backend=:Zygote,
    parallelism=:multithreading,
)

mach = machine(model, X, y)
fit!(mach)
ypred = predict(mach, X)

So it basically learns $y = 2 \cos(x_4 + \alpha) + x_1^2 - \beta$, where the parameters $\alpha$ and $\beta$ can take different values according to the classes column (here there are two classes, i.e. two types of behavior).

This ParametricExpression is just one implementation of AbstractExpression, but you can see how you can do pretty custom things now.


TODO:

  • [x] Allow passing a class feature to MLJ which will have special treatment.
  • [x] Debug why some of the tests seem to get stuck and take 3x longer to finish than normal.
  • [x] Consider documenting this, or just leaving it as an experimental undocumented feature until it stabilizes.
  • [x] Add Enzyme backend.
  • [ ] Add example to docs.
  • [x] Consider moving to Literate.jl for docs?
  • [x] Fix ResourceMonitor weirdness

MilesCranmer, Jun 24 '24

Benchmark Results

Benchmark                              master             e2b369ea768ae9...   master/e2b369ea768ae9...
search/multithreading                  18.7 ± 0.27 s      21.4 ± 0.58 s       0.872
search/serial                          30.6 ± 0.38 s      31.5 ± 0.51 s       0.971
utils/best_of_sample                   0.782 ± 0.29 μs    1 ± 0.33 μs         0.78
utils/check_constraints_x10            11.9 ± 3.2 μs      11.7 ± 3 μs         1.01
utils/compute_complexity_x10/Float64   2.1 ± 0.11 μs      2.07 ± 0.1 μs       1.01
utils/compute_complexity_x10/Int64     2.06 ± 0.1 μs      2.06 ± 0.11 μs      1
utils/compute_complexity_x10/nothing   1.41 ± 0.13 μs     1.46 ± 0.11 μs      0.965
utils/optimize_constants_x10           31 ± 7 ms          30.7 ± 6.6 ms       1.01
time_to_load                           0.966 ± 0.024 s    0.992 ± 0.009 s     0.974

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR. Go to "Actions" -> "Benchmark a pull request" -> [the most recent run] -> "Artifacts" (at the bottom).

github-actions[bot], Jun 24 '24

Pull Request Test Coverage Report for Build 9639805727

Details

  • 239 of 250 (95.6%) changed or added relevant lines in 15 files are covered.
  • 4 unchanged lines in 3 files lost coverage.
  • Overall coverage decreased (-0.1%) to 94.475%

Changes Missing Coverage (covered lines / changed or added lines, %):

  • src/ConstantOptimization.jl: 20 / 21 (95.24%)
  • src/HallOfFame.jl: 16 / 17 (94.12%)
  • src/PopMember.jl: 12 / 13 (92.31%)
  • src/ExpressionBuilder.jl: 82 / 84 (97.62%)
  • src/Mutate.jl: 9 / 11 (81.82%)
  • src/MutationFunctions.jl: 51 / 55 (92.73%)
  • Total: 239 / 250

Files with Coverage Reduction (new missed lines, resulting %):

  • src/SingleIteration.jl: 1 (98.46%)
  • src/Mutate.jl: 1 (88.02%)
  • src/InterfaceDynamicExpressions.jl: 2 (89.04%)
  • Total: 4
Totals Coverage Status
Change from base Build 9535593377: -0.1%
Covered Lines: 2548
Relevant Lines: 2697

💛 - Coveralls

coveralls, Jun 24 '24

Pull Request Test Coverage Report for Build 9686354911

Details

  • 275 of 296 (92.91%) changed or added relevant lines in 17 files are covered.
  • 34 unchanged lines in 5 files lost coverage.
  • Overall coverage decreased (-1.4%) to 93.22%

Changes Missing Coverage (covered lines / changed or added lines, %):

  • ext/SymbolicRegressionSymbolicUtilsExt.jl: 3 / 4 (75.0%)
  • src/ConstantOptimization.jl: 20 / 21 (95.24%)
  • src/HallOfFame.jl: 16 / 17 (94.12%)
  • src/PopMember.jl: 12 / 13 (92.31%)
  • src/ExpressionBuilder.jl: 83 / 85 (97.65%)
  • src/Mutate.jl: 10 / 12 (83.33%)
  • src/MLJInterface.jl: 15 / 18 (83.33%)
  • src/MutationFunctions.jl: 51 / 55 (92.73%)
  • src/InterfaceDynamicExpressions.jl: 19 / 25 (76.0%)
  • Total: 275 / 296

Files with Coverage Reduction (new missed lines, resulting %):

  • src/SingleIteration.jl: 1 (98.46%)
  • src/SymbolicRegression.jl: 1 (94.83%)
  • src/Options.jl: 2 (94.89%)
  • ext/SymbolicRegressionSymbolicUtilsExt.jl: 7 (50.0%)
  • src/MLJInterface.jl: 23 (83.33%)
  • Total: 34
Totals Coverage Status
Change from base Build 9535593377: -1.4%
Covered Lines: 2530
Relevant Lines: 2714

💛 - Coveralls

coveralls, Jun 27 '24

I ran the benchmark on my machine and don't see any performance differences. Not sure what's up with the GitHub Actions version...

MilesCranmer, Jun 27 '24

Pull Request Test Coverage Report for Build 9704727222

Details

  • 301 of 307 (98.05%) changed or added relevant lines in 17 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage increased (+1.3%) to 95.922%

Changes Missing Coverage (covered lines / changed or added lines, %):

  • src/ConstantOptimization.jl: 20 / 21 (95.24%)
  • src/HallOfFame.jl: 16 / 17 (94.12%)
  • src/InterfaceDynamicExpressions.jl: 23 / 24 (95.83%)
  • src/PopMember.jl: 12 / 13 (92.31%)
  • src/MLJInterface.jl: 23 / 25 (92.0%)
  • Total: 301 / 307

Files with Coverage Reduction (new missed lines, resulting %):

  • src/SingleIteration.jl: 1 (98.46%)
  • Total: 1
Totals Coverage Status
Change from base Build 9535593377: 1.3%
Covered Lines: 2611
Relevant Lines: 2722

💛 - Coveralls

coveralls, Jun 28 '24

Weirdly it seems like the gradients are never actually being used in the Zygote optimisation... They are always nothing and I didn't notice it. I guess Optim.jl automatically switches to finite difference if the gradient errors.
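One way to sanity-check this (a toy sketch with a stand-in objective, not SymbolicRegression's actual constant optimizer) is to hand Optim.jl an explicit gradient function that counts its calls and errors loudly if it ever sees nothing:

using Optim

loss(x) = sum(abs2, x)               # stand-in objective
grad_calls = Ref(0)

function grad!(G, x)
    grad_calls[] += 1
    g = 2 .* x                       # stand-in for the Zygote-computed gradient
    g === nothing && error("gradient returned nothing")
    copyto!(G, g)
    return G
end

result = optimize(loss, grad!, randn(5), BFGS())
@show grad_calls[]                   # zero calls would mean the gradient is never used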

MilesCranmer, Jun 28 '24

@gdalle would you happen to know how to force custom ChainRules rrule/frule onto other backends? I think part of the reason they may not work is that eval_tree_array is too complicated for them to trace. But if they just use the rrule/frule, they should be good.

MilesCranmer, Jun 28 '24

They are always nothing and I didn't notice it. I guess Optim.jl automatically switches to finite difference if the gradient errors.

I would be very surprised if it did that. This kind of trick is familiar within the SciML ecosystem but (hopefully) not widespread outside ^^

@gdalle would you happen to know how to force custom ChainRules rrule/frule onto other backends?

What do you mean by that? Which backends, and what do you want to "force"? Happy to jump on a call if that's easier.

gdalle, Jun 28 '24

I remember reading this section on ChainRules.jl: https://juliadiff.org/ChainRulesCore.jl/dev/index.html#Packages-supporting-importing-rules-from-ChainRules.

Several packages do not automatically load rules from ChainRules by default, but support importing rules that were defined using it, e.g. with a macro.

I guess I am wondering if there is a multi-backend way to trigger this "import" to happen, perhaps from within DifferentiationInterface.jl.

I have an rrule defined here: https://github.com/SymbolicML/DynamicExpressions.jl/blob/9e95f0538207c360874615686be5c0ec627aee42/src/ChainRules.jl#L22 and am looking to import it across different backends. (Since otherwise they aren't compatible due to the mutation happening in eval_tree_array)
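For reference, the general shape of such a rule looks something like the sketch below (a toy rule for a hypothetical mutating function, not the actual DynamicExpressions rule): the hand-written pullback hides the in-place work from ChainRules-aware backends like Zygote.

using ChainRulesCore

function mutating_eval(x::Vector{Float64})
    out = similar(x)
    @. out = sin(x)                  # in-place work that Zygote cannot trace on its own
    return out
end

function ChainRulesCore.rrule(::typeof(mutating_eval), x::Vector{Float64})
    y = mutating_eval(x)
    function mutating_eval_pullback(ȳ)
        x̄ = @. cos(x) * ȳ            # hand-derived adjoint
        return (NoTangent(), x̄)
    end
    return y, mutating_eval_pullback
end

# Zygote picks up ChainRules rules automatically, so this now works:
# using Zygote; Zygote.gradient(x -> sum(mutating_eval(x)), rand(3))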

MilesCranmer, Jun 28 '24

No, there's nothing like that at the moment. The reverse functionality does exist, though, as DifferentiateWith (define a chain rule from another backend).
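For anyone reading along, a rough sketch of that reverse direction (the function below is a toy stand-in, and AutoForwardDiff comes from ADTypes.jl; check the DifferentiationInterface docs for the exact call pattern):

using DifferentiationInterface
using ADTypes: AutoForwardDiff
using Zygote, ForwardDiff

function awkward_f(x)
    out = zero(eltype(x))
    for xi in x
        out += xi^2                  # imagine mutation or other Zygote-hostile code here
    end
    return out
end

# wrap the function so that ChainRules-aware backends (like Zygote) differentiate it
# by calling ForwardDiff internally:
f_wrapped = DifferentiateWith(awkward_f, AutoForwardDiff())
Zygote.gradient(f_wrapped, rand(3))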

Which backends are you hoping to use other than Zygote.jl?

gdalle, Jun 28 '24

No there's nothing like that at the moment. The reverse functionality exists though, as DifferentiateWith (define a chain rule from another backend).

Cool!

Which backends are you hoping to use other than Zygote.jl?

Enzyme.jl and ForwardDiff.jl (ForwardDiff seems to work, but I'd like it to use my frule as it's faster than the default traced version)

MilesCranmer, Jun 28 '24

Enzyme.jl and ForwardDiff.jl (ForwardDiff seems to work, but I'd like it to use my frule as it's faster than the default traced version)

Note that neither of these backends is mentioned in https://juliadiff.org/ChainRulesCore.jl/dev/index.html#Packages-that-automatically-load-rules-from-ChainRules, hence my question. As I explain here, ForwardDiff and Enzyme both have their own sets of rules.

  • To import a ChainRule into ForwardDiff, use https://github.com/ThummeTo/ForwardDiffChainRules.jl
  • To import a ChainRule into Enzyme, use these macros (a rough sketch of both imports follows after this list):
    • https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_rrule-Tuple
    • https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_frule-Tuple
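A rough sketch of what those imports might look like, for a hypothetical function that carries hand-written ChainRules rules (a stand-in for eval_tree_array; the exact macro syntax below is my reading of each package's docs and should be double-checked there):

using ChainRulesCore
using ForwardDiff, ForwardDiffChainRules
using Enzyme

# hypothetical function with hand-written ChainRules rules:
my_eval(x::Vector{Float64}) = sum(sin, x)

function ChainRulesCore.rrule(::typeof(my_eval), x::Vector{Float64})
    pullback(ȳ) = (NoTangent(), cos.(x) .* ȳ)
    return my_eval(x), pullback
end

function ChainRulesCore.frule((_, ẋ), ::typeof(my_eval), x::Vector{Float64})
    return my_eval(x), sum(cos.(x) .* ẋ)
end

# ForwardDiffChainRules: add a Dual-number method that reuses the frule above:
@ForwardDiff_frule my_eval(x::Vector{<:ForwardDiff.Dual})

# Enzyme: import the existing rrule/frule for this specific signature:
Enzyme.@import_rrule(typeof(my_eval), Vector{Float64})
Enzyme.@import_frule(typeof(my_eval), Vector{Float64})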

gdalle, Jun 28 '24

I think I am not following (maybe I need coffee) because I see both ForwardDiff and Enzyme right here –

[Screenshot: the ChainRulesCore.jl docs section listing packages that support importing rules from ChainRules, including both ForwardDiff and Enzyme]

It just needs something to automatically import them. (Maybe DifferentiationInterface could expose an API for this in the future, so I can avoid writing an extension for each autodiff backend myself :))

MilesCranmer, Jun 28 '24

Sorry, I was looking at the section above, packages that load automatically (your link from this comment was dead, so I scrolled down from the top). My mistake. You're right, both ForwardDiff and Enzyme support manual loading, and the necessary tools are the ones I linked to. I will have to think about whether DI should support these translation utilities. But I think it's really hard because each rule system has its own syntax, and metaprogramming makes the whole thing even harder. So feel free to open an issue in DI, but it might be low on my list of priorities.

gdalle, Jun 28 '24

Basically, supporting rule translation in DI would involve defining a universal rule definition syntax. The universal differentiation syntax was already tough to iron out, but the universal rule definition syntax is another level entirely.

gdalle, Jun 28 '24

Pull Request Test Coverage Report for Build 9763114573

Details

  • 352 of 357 (98.6%) changed or added relevant lines in 19 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+1.4%) to 96.083%

Changes Missing Coverage (covered lines / changed or added lines, %):

  • src/Dataset.jl: 9 / 10 (90.0%)
  • src/HallOfFame.jl: 16 / 17 (94.12%)
  • src/PopMember.jl: 12 / 13 (92.31%)
  • src/MLJInterface.jl: 23 / 25 (92.0%)
  • Total: 352 / 357
Totals Coverage Status
Change from base Build 9725050445: 1.4%
Covered Lines: 2625
Relevant Lines: 2732

💛 - Coveralls

coveralls, Jul 02 '24

I think the parametric expressions might be getting their metadata (incl. parameters) stripped when a model is updated?

MilesCranmer, Aug 02 '24

Pull Request Test Coverage Report for Build 11204590927

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 466 of 482 (96.68%) changed or added relevant lines in 24 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage increased (+1.2%) to 95.808%

Changes Missing Coverage (covered lines / changed or added lines, %):

  • src/Dataset.jl: 9 / 10 (90.0%)
  • src/ExpressionBuilder.jl: 98 / 99 (98.99%)
  • src/HallOfFame.jl: 16 / 17 (94.12%)
  • src/PopMember.jl: 12 / 13 (92.31%)
  • src/Utils.jl: 25 / 26 (96.15%)
  • src/MLJInterface.jl: 23 / 25 (92.0%)
  • src/Operators.jl: 7 / 9 (77.78%)
  • ext/SymbolicRegressionEnzymeExt.jl: 15 / 18 (83.33%)
  • src/Options.jl: 49 / 53 (92.45%)
  • Total: 466 / 482

Files with Coverage Reduction (new missed lines, resulting %):

  • src/Configure.jl: 1 (94.16%)
  • Total: 1
Totals Coverage Status
Change from base Build 10700134500: 1.2%
Covered Lines: 2651
Relevant Lines: 2767

💛 - Coveralls

coveralls, Aug 13 '24

I was encountering some issues with constraint parsing in Options.jl. Check out the comment. Don't know why the test cases don't catch the issue.

atharvas, Sep 12 '24

Going to punt on StructuredExpression until later. @eelregit let me know if you are at all interested in this! StructuredExpression would let you evolve within a fixed functional form. It seems like there are a couple of missing methods needed for it to work, but hopefully it won't take too much effort. I'll have to pause on this side of things for now.

MilesCranmer, Oct 06 '24

Seems like the garbage collection is going crazy in the tests, which is why they are so slow. The reason Julia 1.6 and 1.8 are much faster is, I think, that DispatchDoctor.jl is turned off there. So something about DispatchDoctor.jl is causing the GC to overwork itself... Possibly related to https://github.com/MilesCranmer/DispatchDoctor.jl/issues/57 and https://github.com/MilesCranmer/DispatchDoctor.jl/issues/58?

MilesCranmer, Oct 06 '24

Fixed the performance regression in the unit tests with https://github.com/SymbolicML/DynamicExpressions.jl/pull/102/commits/74c8dc1db5a0b192c49a35d7a8b1d8e5a792cd61.

Edit: it still seems to hang around a bit. It's definitely something to do with DispatchDoctor, judging from the PProf outputs, so it won't affect actual runtime performance, just the testing. Probably fine to merge for now.

MilesCranmer, Oct 06 '24