sbi
Refactor z-scoring tests to fix shapes and enable MNLE support
Description
This PR refactors test_z_scoring_structured in sbi/tests/sbiutils_test.py to address Issue #1559.
It resolves variable-naming confusion, removes code duplication, fixes dimensionality bugs, and enables testing of the `mnle` model, which was previously skipped.
Changes
- Variable Renaming: Renamed `builder` to `build_fn` and `net` to `estimator` to clearly distinguish the factory function from the instantiated network.
- Deduplication: Consolidated the separate loops for different builders into a single parameterized test that iterates over a shared model list.
- MNLE Support: Added logic to generate mixed data (continuous + discrete) specifically for `mnle`, satisfying its input requirements and preventing crashes.
- Dimensionality Fixes:
  - Set `num_dim = 3` for both `theta` and `x`. This ensures `mnle` has more than one continuous dimension (preventing `NaN` in std calculations) and allows `transform_to_unconstrained` to work via implicit bound checking without passing an explicit `prior` argument.
- Shape Correction: Added `.unsqueeze(0)` to flow inputs to satisfy the `(sample, batch, dim)` shape requirement, fixing `AssertionError` failures.
- Argument Safety: Removed the explicit `prior` kwarg from flow builders to prevent a `TypeError` in Zuko flows, while ensuring bounds checks still pass via the `x_dist` fallback.
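The mixed-data and shape fixes can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the test code from this PR: the model list, the helper names `make_data` and `as_flow_input`, and the `mnle` column layout (continuous columns first, one binary discrete column last) are all assumptions. With `num_dim = 3`, the `mnle` data keeps two continuous dimensions, and `.unsqueeze(0)` adds the leading sample dimension.

```python
import torch

# Hypothetical shared model list; the real test may cover other models.
MODELS = ["maf", "nsf", "mnle"]
NUM_DIM = 3  # same dimensionality for theta and x, as described in the PR
NUM_SAMPLES = 100


def make_data(model: str, num_samples: int = NUM_SAMPLES, num_dim: int = NUM_DIM):
    """Return (theta, x) batches of shape (num_samples, num_dim). Hypothetical helper."""
    theta = torch.randn(num_samples, num_dim)
    if model == "mnle":
        # Assumed layout: continuous columns first, one binary discrete column last,
        # so num_dim = 3 leaves two continuous dimensions.
        x_cont = torch.randn(num_samples, num_dim - 1)
        x_disc = torch.randint(0, 2, (num_samples, 1)).float()
        x = torch.cat([x_cont, x_disc], dim=1)
    else:
        x = torch.randn(num_samples, num_dim)
    return theta, x


def as_flow_input(t: torch.Tensor) -> torch.Tensor:
    """Add a leading sample dimension: (batch, dim) -> (1, batch, dim)."""
    return t.unsqueeze(0)


torch.manual_seed(0)
for model in MODELS:
    theta, x = make_data(model)
    # In the real test, build_fn is the factory for `model` and
    # estimator = build_fn(...) is the network it returns (omitted here).
    # theta is used as the estimator input purely for illustration.
    flow_input = as_flow_input(theta)
    print(model, tuple(flow_input.shape), tuple(x.shape))
```

Using `unsqueeze(0)` rather than a `reshape` leaves the batch and event dimensions untouched, so one helper serves every model in the shared list.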
Related Issue
- Closes #1559
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:warning: Please upload report for BASE (main@3316888).
@@ Coverage Diff @@
## main #1711 +/- ##
=======================================
Coverage ? 84.68%
=======================================
Files ? 137
Lines ? 11493
Branches ? 0
=======================================
Hits ? 9733
Misses ? 1760
Partials ? 0
| Flag | Coverage Δ |
|---|---|
| unittests | 84.68% <ø> (?) |
Flags with carried forward coverage won't be shown.
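The coverage figures above are internally consistent, which a short calculation confirms. This sketch assumes Codecov's formula hits / (hits + misses + partials) and its default of rounding coverage down to two decimal places; both are assumptions about the Codecov configuration used here.

```python
from decimal import Decimal, ROUND_DOWN

hits, misses, partials = 9733, 1760, 0
lines = hits + misses + partials  # should equal the "Lines" row (11493)

# Coverage percentage before rounding.
raw = Decimal(hits) * 100 / Decimal(lines)
# Assumed Codecov default: round down to two decimal places.
reported = raw.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
print(lines, reported)  # 11493 84.68
```

Truncation matters here: the unrounded value is about 84.686%, which conventional half-up rounding would report as 84.69%.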