FMPE with higher-dim conditions
Fixes #1702
This PR fixes a bug in FMPE posterior sampling when the condition's event dimension is greater than 1. I also added a test that checks that the samples have the correct shape. It is similar to posterior_nn_test.py::test_batched_score_sample_with_different_x, so there is some code repetition, but the minimal setup looks quite different when dealing with matrix variables, so I think a separate test is clearer than trying to fit this case into the same test function.
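To make the shape expectation concrete, here is a minimal, library-free sketch of the bookkeeping involved. The helper names `split_condition_shape` and `expected_sample_shape` are hypothetical and are not part of sbi, and the `(num_samples, *batch_shape, theta_dim)` output layout is an assumption about what a batched shape test would check, not a statement of the actual implementation.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_condition_shape(x_shape: Shape, event_ndim: int) -> Tuple[Shape, Shape]:
    """Split a condition shape into (batch_shape, event_shape).

    Hypothetical helper: the trailing `event_ndim` dimensions are treated as
    one observation (the event), everything in front of them as batch.
    """
    if event_ndim < 1 or event_ndim > len(x_shape):
        raise ValueError(
            f"event_ndim={event_ndim} is incompatible with shape {x_shape}"
        )
    split = len(x_shape) - event_ndim
    return x_shape[:split], x_shape[split:]


def expected_sample_shape(
    num_samples: int, x_shape: Shape, event_ndim: int, theta_dim: int
) -> Shape:
    """Shape a batched sampler is assumed to return for a given condition.

    Assumption: samples are laid out as (num_samples, *batch_shape, theta_dim).
    """
    batch_shape, _ = split_condition_shape(x_shape, event_ndim)
    return (num_samples, *batch_shape, theta_dim)


# A batch of 3 matrix-valued conditions, each of shape (4, 5).
x_shape = (3, 4, 5)

# Treating the last two dimensions as the event gives one batch dimension.
print(expected_sample_shape(10, x_shape, event_ndim=2, theta_dim=2))

# Treating only the last dimension as the event would instead read the
# 4 rows of each matrix as an extra batch dimension.
print(expected_sample_shape(10, x_shape, event_ndim=1, theta_dim=2))
```

The two calls differ only in `event_ndim`, which shows why code that implicitly assumes a one-dimensional event can produce samples of the wrong shape for matrix-valued conditions.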
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:warning: Please upload report for BASE (main@2c216c2).
:white_check_mark: All tests successful. No failed tests found.
Additional details and impacted files
@@ Coverage Diff @@
## main #1704 +/- ##
=======================================
Coverage ? 84.65%
=======================================
Files ? 137
Lines ? 11489
Branches ? 0
=======================================
Hits ? 9726
Misses ? 1763
Partials ? 0
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 84.65% <100.00%> (?) | |
Flags with carried forward coverage won't be shown.
| Files with missing lines | Coverage Δ | |
|---|---|---|
| sbi/inference/potentials/vector_field_potential.py | 74.56% <100.00%> (ø) | |
| sbi/samplers/ode_solvers/zuko_ode.py | 100.00% <100.00%> (ø) | |