Add a common MDN network interface (consistent with new DensityEstimators)
Description
The MDN (Mixture Density Network, i.e. a Gaussian mixture density network) currently receives several "special" treatments in the codebase:
- SNPE_A relies on it and implements a custom SNPE_A_MDN wrapper with specific corrections.
- SNPE_C uses a specialized loss function for MDNs.
- conditional_density.py defines a custom ConditionedMDN to handle "marginalization" properties.
Given these various custom implementations, it would be beneficial to consolidate all MDN-related wrappers and functionalities into a single class: MDNDensityEstimator. This new class will adhere to the existing DensityEstimator structure, providing a more unified and streamlined approach.
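A minimal sketch of what such a consolidated interface could look like. It is written in plain Python (standard library only) rather than against sbi's actual classes, and all names in it (DensityEstimatorSketch, MDNDensityEstimatorSketch, condition) are hypothetical. It assumes the mixture parameters are passed in directly instead of being predicted by a network from a context, and that covariances are diagonal so that conditioning on a subset of dimensions stays closed-form. It gathers in one class the pieces that are currently spread out: a log-density, a default loss (plain mean negative log-likelihood, not SNPE_C's specialized loss), sampling, and conditioning on fixed dimensions.

```python
import math
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Sequence


class DensityEstimatorSketch(ABC):
    """Hypothetical stand-in for a shared density-estimator interface."""

    @abstractmethod
    def log_prob(self, x: Sequence[float]) -> float: ...

    @abstractmethod
    def sample(self, rng: random.Random) -> List[float]: ...

    def loss(self, xs: Sequence[Sequence[float]]) -> float:
        # Default training loss: mean negative log-likelihood over a batch.
        return -sum(self.log_prob(x) for x in xs) / len(xs)


@dataclass
class MDNDensityEstimatorSketch(DensityEstimatorSketch):
    """Gaussian mixture with diagonal covariances (assumption: no context net)."""

    weights: List[float]      # mixture weights, assumed positive and summing to 1
    means: List[List[float]]  # one mean vector per component
    stds: List[List[float]]   # one std vector per component (diagonal covariance)

    def _component_log_probs(self, x: Sequence[float], dims: Sequence[int]) -> List[float]:
        # log w_k + log N(x | mu_k[dims], diag(sigma_k[dims]^2)) for every component k
        out = []
        for w, mu, sd in zip(self.weights, self.means, self.stds):
            lp = math.log(w)
            for xi, d in zip(x, dims):
                z = (xi - mu[d]) / sd[d]
                lp += -0.5 * z * z - math.log(sd[d]) - 0.5 * math.log(2 * math.pi)
            out.append(lp)
        return out

    def log_prob(self, x: Sequence[float]) -> float:
        lps = self._component_log_probs(x, range(len(x)))
        m = max(lps)
        return m + math.log(sum(math.exp(lp - m) for lp in lps))  # log-sum-exp

    def sample(self, rng: random.Random) -> List[float]:
        # Pick a component by weight, then draw each dimension independently.
        k = rng.choices(range(len(self.weights)), weights=self.weights)[0]
        return [rng.gauss(mu, sd) for mu, sd in zip(self.means[k], self.stds[k])]

    def condition(self, dims: Sequence[int], values: Sequence[float]) -> "MDNDensityEstimatorSketch":
        """Condition on x[dims] = values; returns a mixture over the remaining dims.

        With diagonal covariances, each component's remaining dims are unchanged;
        only the weights are re-scaled by how well each component explains `values`.
        """
        lps = self._component_log_probs(values, dims)
        m = max(lps)
        unnorm = [math.exp(lp - m) for lp in lps]
        total = sum(unnorm)
        keep = [d for d in range(len(self.means[0])) if d not in dims]
        return MDNDensityEstimatorSketch(
            weights=[u / total for u in unnorm],
            means=[[mu[d] for d in keep] for mu in self.means],
            stds=[[sd[d] for d in keep] for sd in self.stds],
        )


if __name__ == "__main__":
    mdn = MDNDensityEstimatorSketch(
        weights=[0.5, 0.5],
        means=[[-1.0, 10.0], [1.0, 20.0]],
        stds=[[1.0, 1.0], [1.0, 1.0]],
    )
    print(mdn.log_prob([0.0, 15.0]))
    print(mdn.condition([0], [1.0]).weights)
```

Because the shared methods are declared abstract, a subclass that omits one of them (say, `sample`) raises a `TypeError` at instantiation instead of failing later when the missing method is called. Under the diagonal-covariance assumption, marginalizing out dimensions is simply dropping them from `means` and `stds`; full covariances would require the usual Gaussian conditioning formulas instead.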
Advantages:
- Reduced Code Duplication: By consolidating the MDN-related code, we eliminate repetitive implementations and centralize the logic.
- Improved Typing: A unified class will provide more consistent and clear typing across the library.
- Simplified Maintenance: Consolidating MDN functionality into one place will make future updates, bug fixes, and enhancements easier to manage.
This refactor will improve the readability, maintainability, and scalability of the codebase.
For a first approach, see #1042.
I find the SNPE_A_MDN class quite confusing. Is there a loss implemented for it yet? Why does it only inherit from DensityEstimator but not implement all methods?
This is being worked on in #1042 but it will not make it into the release in August.