sbi

Add a common MDN network interface (consistent with new DensityEstimators)

Open • manuelgloeckler opened this issue 1 year ago • 2 comments

Description

The MDN (Mixture Density Network), specifically a Gaussian mixture density network, currently receives several "special" treatments in the codebase:

  • SNPE_A relies on it and implements a custom SNPE_A_MDN wrapper with specific corrections.
  • SNPE_C uses a specialized loss function for MDNs.
  • conditional_density.py defines a custom ConditionedMDN to handle "marginalization" properties.

Given these various custom implementations, it would be beneficial to consolidate all MDN-related wrappers and functionality into a single class: MDNDensityEstimator. This new class would adhere to the existing DensityEstimator structure, so that MDN-specific logic lives in one place instead of being spread across SNPE_A, SNPE_C, and conditional_density.py.
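The consolidation described above can be sketched as one class that bundles density evaluation, the training loss, sampling, and marginalization. This is a minimal, stdlib-only illustration, not the sbi implementation (which is built on PyTorch): the class name MDNDensityEstimator comes from this issue, but the method names `log_prob`, `loss`, `sample`, and `marginalize`, their signatures, and the diagonal-covariance restriction are assumptions. Mixture parameters are fixed here instead of being predicted from a condition by a neural network.

```python
import math
import random
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class MDNDensityEstimator:
    """Hypothetical sketch of a diagonal-covariance Gaussian mixture estimator.

    In a real MDN these parameters would be outputs of a network that
    takes the condition as input; here they are fixed for illustration.
    """
    logits: List[float]        # unnormalized mixture weights, one per component
    means: List[List[float]]   # means[k][d] for component k, dimension d
    stds: List[List[float]]    # standard deviations[k][d], all > 0

    def _log_weights(self) -> List[float]:
        # Numerically stable log-softmax over the component logits.
        m = max(self.logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in self.logits))
        return [l - log_z for l in self.logits]

    def log_prob(self, x: Sequence[float]) -> float:
        # log p(x) = logsumexp_k [log w_k + sum_d log N(x_d | mu_kd, sigma_kd)]
        terms = []
        for log_w, mu, sd in zip(self._log_weights(), self.means, self.stds):
            comp = sum(
                -0.5 * ((xi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
                for xi, m, s in zip(x, mu, sd)
            )
            terms.append(log_w + comp)
        top = max(terms)
        return top + math.log(sum(math.exp(t - top) for t in terms))

    def loss(self, batch: Sequence[Sequence[float]]) -> float:
        # Mean negative log-likelihood. In this sketch, algorithm-specific
        # losses or corrections would be layered on top of this method.
        return -sum(self.log_prob(x) for x in batch) / len(batch)

    def sample(self, n: int, rng: random.Random) -> List[List[float]]:
        # Ancestral sampling: pick a component, then draw from its Gaussian.
        weights = [math.exp(lw) for lw in self._log_weights()]
        out = []
        for _ in range(n):
            k = rng.choices(range(len(weights)), weights=weights)[0]
            out.append([rng.gauss(m, s) for m, s in zip(self.means[k], self.stds[k])])
        return out

    def marginalize(self, dims: Sequence[int]) -> "MDNDensityEstimator":
        # The marginal of a Gaussian mixture over a subset of dimensions is
        # again a mixture with the same weights and sliced components.
        return MDNDensityEstimator(
            logits=list(self.logits),
            means=[[mu[d] for d in dims] for mu in self.means],
            stds=[[sd[d] for d in dims] for sd in self.stds],
        )
```

Because `marginalize` returns another `MDNDensityEstimator` in this sketch, the kind of "marginalization" behavior that ConditionedMDN provides separately would become an ordinary method of the same class.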

Advantages:

  • Reduced Code Duplication: By consolidating the MDN-related code, we eliminate repetitive implementations and centralize the logic.
  • Improved Typing: A unified class will provide more consistent and clear typing across the library.
  • Simplified Maintenance: Consolidating MDN functionality into one place will make future updates, bug fixes, and enhancements easier to manage.

This refactor will improve the readability, maintainability, and scalability of the codebase.

For a first approach, see #1042.

manuelgloeckler • Mar 13 '24 16:03

I find the SNPE_A_MDN class quite confusing. Is there a loss implemented for it yet? Why does it only inherit from DensityEstimator but not implement all methods?
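One way to catch a wrapper that inherits from a base class without implementing all of its methods is to declare the interface with Python's `abc` module, so an incomplete subclass fails at instantiation rather than later at call time. Whether sbi's DensityEstimator uses this mechanism is not stated here; the class and method names below are hypothetical, and the method bodies are toy stand-ins.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class DensityEstimatorBase(ABC):
    """Hypothetical interface; method names are illustrative only."""

    @abstractmethod
    def log_prob(self, x: Sequence[float]) -> float: ...

    @abstractmethod
    def loss(self, batch: Sequence[Sequence[float]]) -> float: ...

    @abstractmethod
    def sample(self, n: int) -> list: ...


class PartialMDNWrapper(DensityEstimatorBase):
    # Implements only log_prob, so Python refuses to instantiate it.
    def log_prob(self, x):
        return 0.0


class CompleteMDNWrapper(PartialMDNWrapper):
    # Toy implementations of the remaining abstract methods.
    def loss(self, batch):
        return -sum(self.log_prob(x) for x in batch) / len(batch)

    def sample(self, n):
        return [[0.0] for _ in range(n)]


def can_instantiate(cls) -> bool:
    # Instantiating a class with unimplemented abstract methods raises TypeError.
    try:
        cls()
    except TypeError:
        return False
    return True
```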

bkmi • Mar 19 '24 10:03

This is being worked on in #1042 but it will not make it into the release in August.

janfb • Jul 22 '24 07:07