MXFusion
Examples/benchmarking
Description of changes:
This PR adds benchmarks comparing Bayesian neural networks (BNNs) trained with mean-field variational inference (VI) against a non-Bayesian neural network. Hopefully this provides a useful starting point for further analysis (e.g. other kinds of BNN).
The script examples/benchmarking/bnn_classification_benchmark.py runs through several datasets (MNIST, FashionMNIST, CIFAR10, CIFAR100) with 3 different NN architectures. Several metrics are computed: accuracy, MSE (i.e. the Brier score) and log loss. "Sensible" defaults are set for the hyperparameters; no hyperparameter tuning is performed. Results are stored in the file results.txt as a list of JSON strings.
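A minimal sketch of how the three metrics could be computed from predicted class probabilities, and how one result could be serialised as a JSON line. It assumes the multiclass Brier score (squared error against the one-hot label, summed over classes and averaged over examples) and one JSON object per line in results.txt; the function names and record keys are hypothetical, not taken from the script.

```python
import json
import math
from typing import Dict, Sequence


def classification_metrics(probs: Sequence[Sequence[float]],
                           labels: Sequence[int]) -> Dict[str, float]:
    """Accuracy, multiclass Brier score and log loss (hypothetical helper)."""
    eps = 1e-15  # clip probabilities so log(0) cannot occur
    n = len(labels)
    correct, brier, log_loss = 0, 0.0, 0.0
    for p, y in zip(probs, labels):
        predicted = max(range(len(p)), key=lambda k: p[k])  # argmax class
        correct += int(predicted == y)
        # Squared error against the one-hot encoding of the true label
        brier += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
        log_loss -= math.log(max(p[y], eps))
    return {"accuracy": correct / n, "brier": brier / n, "log_loss": log_loss / n}


def to_result_line(record: Dict[str, object]) -> str:
    # One JSON object per line, ready to append to results.txt (assumed layout)
    return json.dumps(record, sort_keys=True) + "\n"


metrics = classification_metrics([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], [0, 1, 1])
print(to_result_line({"dataset": "MNIST", "model": "bnn", **metrics}), end="")
```

Writing one self-contained JSON object per line keeps the file appendable across runs and easy to load row by row in a notebook.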
I also added a notebook in the notebooks directory for exploring the results. It writes figures to examples/benchmarking/figs (the figures are also included in this PR).
Changes to MXFusion core files:
- mxfusion/components/functions/mxfusion_gluon_function.py: made the exception more helpful
- mxfusion/inference/batch_loop.py: added a callback for custom status messages
- mxfusion/inference/grad_based_inference.py: added GradIteratorBasedInference, a version of GradBasedInference that operates on a data loader
- mxfusion/inference/minibatch_loop.py: fixed a bug that stopped it working on GPUs; added a callback for custom status messages
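The idea behind GradIteratorBasedInference (driving gradient updates from a data loader rather than from arrays held in memory) and the new status-message callbacks can be illustrated in plain Python. This is a generic sketch on a toy one-parameter regression, not MXFusion's API: run_loader_loop, its arguments and the callback signature are all hypothetical.

```python
from typing import Callable, Iterable, List, Optional, Tuple

# A minibatch is (inputs, targets); the "data loader" is any re-iterable of minibatches.
Batch = Tuple[List[float], List[float]]
StatusCallback = Callable[[int, int, float], None]


def run_loader_loop(loader: Iterable[Batch], w: float, learning_rate: float,
                    max_epochs: int, status_callback: Optional[StatusCallback] = None) -> float:
    """SGD for the model y ~ w * x, pulling minibatches from `loader` each epoch (hypothetical)."""
    for epoch in range(max_epochs):
        for batch_idx, (xs, ys) in enumerate(loader):
            n = len(xs)
            residuals = [w * x - y for x, y in zip(xs, ys)]
            loss = sum(r * r for r in residuals) / n  # loss before this update
            grad = sum(2.0 * r * x for r, x in zip(residuals, xs)) / n  # d(loss)/dw
            w -= learning_rate * grad
            if status_callback is not None:
                # The caller decides how to report progress (print, log, progress bar...)
                status_callback(epoch, batch_idx, loss)
    return w


messages = []
loader = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]  # data from y = 3x
w = run_loader_loop(loader, w=0.0, learning_rate=0.01, max_epochs=50,
                    status_callback=lambda e, b, loss_value: messages.append(
                        "epoch {} batch {} loss {:.4f}".format(e, b, loss_value)))
print(round(w, 4), messages[-1])
```

Taking the loader as a re-iterable (here a plain list) means each epoch restarts it, as framework data loaders typically do; a one-shot generator would be exhausted after the first epoch.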
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
Looks cool, Tom. I haven't had a chance to go through the results yet, but the changes to the core MXFusion codebase look fine to me.
@meissnereric can you have a look at the failing tests? Don't think this was happening before.
I think this was happening before, I remember seeing it.
The reason is that you're using Python 3.6-only string formatting (f-strings) in places. This style, f"Context device id {ctx.device_id} outside range of list {ctx_list} or None", isn't supported in 3.4/3.5; use the classic "blah".format() style instead. Shouldn't be a big change, thanks Tom!
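For reference, a before/after of that message. f-strings are a syntax error before Python 3.6, so the whole module fails to import on 3.4/3.5, not just the one line. str.format accepts the same {ctx.device_id} attribute lookup, so the change is mechanical. FakeContext and out_of_range_message are hypothetical stand-ins for the real context object and call site.

```python
from collections import namedtuple

# Hypothetical stand-in for the real context object; only .device_id is used here.
FakeContext = namedtuple("FakeContext", ["device_id"])


def out_of_range_message(ctx, ctx_list):
    # Python 3.6+ only (a SyntaxError when the module is parsed on 3.4/3.5):
    #   f"Context device id {ctx.device_id} outside range of list {ctx_list} or None"
    # Portable equivalent: str.format supports the same attribute lookup in fields.
    return "Context device id {ctx.device_id} outside range of list {ctx_list} or None".format(
        ctx=ctx, ctx_list=ctx_list)


print(out_of_range_message(FakeContext(device_id=3), [0, 1]))
# Context device id 3 outside range of list [0, 1] or None
```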
Codecov Report
Merging #143 into develop will decrease coverage by 0.41%. The diff coverage is 30.43%.
@@ Coverage Diff @@
## develop #143 +/- ##
===========================================
- Coverage 85.19% 84.78% -0.42%
===========================================
Files 78 78
Lines 3850 3917 +67
Branches 654 666 +12
===========================================
+ Hits 3280 3321 +41
- Misses 376 395 +19
- Partials 194 201 +7
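Codecov computes coverage as hits divided by tracked lines, where tracked lines = hits + misses + partials. A quick arithmetic check against the Coverage Diff numbers above reproduces both coverage figures and the headline 0.41% drop:

```python
def coverage_pct(hits: int, misses: int, partials: int) -> float:
    # Tracked lines are hits + misses + partials; only full hits count as covered
    lines = hits + misses + partials
    return 100.0 * hits / lines


before = coverage_pct(3280, 376, 194)  # develop: 3850 lines
after = coverage_pct(3321, 395, 201)   # #143: 3917 lines
print("{:.2f}% -> {:.2f}% ({:+.2f})".format(before, after, after - before))
# 85.19% -> 84.78% (-0.41)
```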
Impacted Files | Coverage Δ | |
---|---|---|
...on/components/functions/mxfusion_gluon_function.py | 86.9% <0%> (ø) | ↑ |
mxfusion/inference/batch_loop.py | 80% <0%> (-20%) | ↓ |
mxfusion/inference/__init__.py | 100% <100%> (ø) | ↑ |
mxfusion/inference/grad_based_inference.py | 71.42% <33.33%> (-19.88%) | ↓ |
mxfusion/inference/minibatch_loop.py | 73.33% <40%> (-4.45%) | ↓ |
mxfusion/inference/inference_parameters.py | 84.4% <0%> (-4.49%) | ↓ |
mxfusion/models/factor_graph.py | 84.72% <0%> (-0.17%) | ↓ |
mxfusion/util/graph_serialization.py | | |
mxfusion/util/serialization.py | 85.71% <0%> (ø) | |
mxfusion/inference/inference.py | 83.33% <0%> (+1.51%) | ↑ |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 96ccde7...01fd48c.