WIP: add support for AffineAsVariable
Closes #214
Following up on our discussion at JuMP-dev. I have some questions, mainly about how this would be used.
Would you want to use this as an explicit layer, or as something like:
```julia
y, _ = MathOptAI.add_predictor(
    model,
    chain,
    x;
    config = Dict(Flux.Dense => MathOptAI.AffineAsVariable),
)
```
How would this be trained with only a single input-output realisation? Maybe I need to read the 20-page paper...
Or is the idea that you'd have a small AffineAsVariable layer that is then applied over many time steps?
For dynamic optimization, we would commonly use one network over many time steps. For training, we would pose it as a parameter estimation problem over multiple dynamic trajectory datasets. So we would need to specify an NN over a set of inputs/outputs, where the output variables are indexed over the set but the weight variables are shared across all inputs.
Hence, in terms of syntax, we would need to be able to add the predictor multiple times for different inputs, with every call using the same variables for A and b.
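The shared-weights pattern can be sketched in plain JuMP, without relying on the WIP `AffineAsVariable` API. This is a minimal illustration under stated assumptions: `add_shared_affine` is a hypothetical helper (not part of MathOptAI), the data are synthetic samples from a known affine map, and Ipopt is used as the second-order solver. The weight variables `A` and `b` are created once; the predictor is then added once per input (e.g. per time step), each time reusing the same `A` and `b`, and the weights are fit by least squares.

```julia
using JuMP, Ipopt

# Hypothetical helper (not MathOptAI's API): adds outputs y = A * x + b for a
# single input vector `x`, reusing the caller's weight variables `A` and `b`.
function add_shared_affine(model, A, b, x)
    n_out, n_in = size(A)
    y = @variable(model, [1:n_out])
    @constraint(model, [i in 1:n_out], y[i] == sum(A[i, j] * x[j] for j in 1:n_in) + b[i])
    return y
end

# Synthetic data (an assumption for this sketch): T samples, e.g. one per time
# step, generated by a known affine map so the fitted weights can be checked.
T = 20
A_true = [1.5 -0.5]
b_true = [0.25]
X = [[t / T, (t / T)^2] for t in 1:T]
Y = [A_true * x + b_true for x in X]

model = Model(Ipopt.Optimizer)
set_silent(model)
# The weights and bias are created once...
@variable(model, A[1:1, 1:2])
@variable(model, b[1:1])
# ...and the predictor is added once per sample, always with the same A and b.
ys = [add_shared_affine(model, A, b, x) for x in X]
# Parameter estimation: least-squares fit of the shared weights to all samples.
@objective(model, Min, sum((ys[t][i] - Y[t][i])^2 for t in 1:T, i in 1:1))
optimize!(model)
println("A = ", value.(A), ", b = ", value.(b))
```

If the inputs `x` were themselves decision variables (for example, states in a dynamic model), the constraints would become bilinear in `A` and `x`, which is one reason a second-order NLP solver such as Ipopt is a natural fit for this kind of training problem.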
For reference, Gekko provides an API for training NNs with second-order solvers (e.g., Ipopt) in the Gekko AML: https://gekko.readthedocs.io/en/latest/brain.html
Okay, I need to think about this a bit more.
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:white_check_mark: Project coverage is 100.00%. Comparing base (2028b5c) to head (3f4076b).
:warning: Report is 1 commit behind head on main.
| Coverage diff | main | #215 | +/- |
|---|---|---|---|
| Coverage | 100.00% | 100.00% | |
| Files | 28 | 29 | +1 |
| Lines | 742 | 768 | +26 |
| Hits | 742 | 768 | +26 |