
Refactor z-scoring tests to fix shapes and enable MNLE support

Open • satwiksps opened this pull request 2 weeks ago • 1 comment

Description

This PR refactors test_z_scoring_structured in sbi/tests/sbiutils_test.py to address Issue #1559. It resolves confusing variable names, removes duplicated code, fixes dimensionality bugs, and enables testing of the mnle model, which was previously skipped.

Changes

  • Variable Renaming: Renamed builder to build_fn and net to estimator to clearly distinguish between the factory function and the instantiated network.
  • Deduplication: Consolidated the separate per-builder loops into a single parameterized loop over a shared list of models.
  • MNLE Support: Added logic to generate mixed data (continuous + discrete) specifically for mnle, satisfying its input requirements and preventing crashes.
  • Dimensionality Fixes:
    • Set num_dim = 3 for both theta and x. This gives mnle more than one continuous dimension, which prevents NaN values in the std calculations, and lets transform_to_unconstrained work via implicit bound checking without an explicit prior argument.
  • Shape Correction: Added .unsqueeze(0) to flow inputs to satisfy the (sample, batch, dim) shape requirement, fixing AssertionError failures.
  • Argument Safety: Removed the explicit prior kwarg from flow builders to prevent TypeError in Zuko flows, while ensuring bounds checks still pass via the x_dist fallback.
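
The mixed-data, z-scoring, and shape changes above can be illustrated with a small PyTorch sketch. This is not code from the PR or from sbi. make_mixed_x and structured_z_stats are hypothetical helpers. The column layout (continuous columns first, one discrete column last) is an assumption about what mnle expects. The statistic assumes that structured z-scoring computes a std per sample across dimensions and averages it over the batch. Under that assumption, a single continuous dimension means an unbiased std over one element, which is NaN.

```python
import warnings

import torch


def make_mixed_x(num_samples: int, num_dim: int = 3, num_categories: int = 2, seed: int = 0) -> torch.Tensor:
    """Hypothetical helper: (num_dim - 1) continuous columns followed by one
    discrete column. Assumption: this is the layout mnle expects."""
    g = torch.Generator().manual_seed(seed)
    continuous = torch.randn(num_samples, num_dim - 1, generator=g)
    discrete = torch.randint(0, num_categories, (num_samples, 1), generator=g).float()
    return torch.cat([continuous, discrete], dim=1)


def structured_z_stats(batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of 'structured' z-scoring: one scalar mean over all
    entries, plus a per-sample std across dimensions averaged over the batch."""
    mean = batch.mean()
    per_sample_std = batch.std(dim=1)  # unbiased; NaN if a sample has a single dimension
    return mean, per_sample_std.mean()


# num_dim = 3 -> two continuous columns, so every per-sample std is finite.
x = make_mixed_x(num_samples=100, num_dim=3)
x_cont = x[:, :-1]  # only the continuous part is z-scored in this sketch
mean, std = structured_z_stats(x_cont)

# num_dim = 2 -> one continuous column, so the unbiased std is NaN.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # newer torch versions warn about degrees of freedom <= 0
    _, std_single = structured_z_stats(make_mixed_x(100, num_dim=2)[:, :-1])

print(torch.isfinite(std).item(), torch.isnan(std_single).item())  # True True

# Flow inputs: prepend a sample dimension to get (sample, batch, dim).
theta = torch.randn(100, 3)
theta_sbd = theta.unsqueeze(0)
print(tuple(theta_sbd.shape))  # (1, 100, 3)
```

The final lines only illustrate the (sample, batch, dim) layout that the flow shape assertion is described as requiring. The exact call in the refactored test may differ.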

Related Issue

  • Closes #1559

satwiksps • Dec 06 '25 16:12

Codecov Report

✅ All modified and coverable lines are covered by tests. ⚠️ Please upload report for BASE (main@3316888).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1711   +/-   ##
=======================================
  Coverage        ?   84.68%           
=======================================
  Files           ?      137           
  Lines           ?    11493           
  Branches        ?        0           
=======================================
  Hits            ?     9733           
  Misses          ?     1760           
  Partials        ?        0           
Flag Coverage Δ
unittests 84.68% <ø> (?)
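
As a sanity check on the report, Codecov defines coverage as hits / (hits + misses + partials). The sketch below recomputes it from the table, assuming the project uses Codecov's default settings of precision 2 with rounding down. Conventional round-half-up would give 84.69 instead.

```python
from decimal import ROUND_DOWN, Decimal

# Figures copied from the coverage table above.
hits, misses, partials = 9733, 1760, 0
lines = hits + misses + partials  # should match the "Lines" row

# Assumption: default Codecov rounding (precision 2, round down).
ratio = Decimal(hits) / Decimal(lines) * 100
coverage = ratio.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
print(lines, coverage)  # 11493 84.68
```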

Flags with carried forward coverage won't be shown.

codecov[bot] • Dec 06 '25 16:12