pymc
Refactor `EmpiricalGroup` to take `InferenceData` inputs
To reduce our dependence on `MultiTrace`, and to better support workflows where `InferenceData` objects are saved to disk, the `EmpiricalGroup` approximation should be refactored.
Currently it accepts only `MultiTrace` inputs, but only because it needs access to the transformed posterior draws.
Since the approximation is created inside a model context, it should be straightforward to get the transforms from the `Model` and apply them to the untransformed posterior draws.
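The idea of re-applying the model's transforms can be sketched as follows. This is a minimal sketch under loudly stated assumptions: a plain `dict` of NumPy arrays (shape `(chain, draw, ...)`) stands in for `idata.posterior`, and the hypothetical `FORWARD_TRANSFORMS` mapping stands in for the transforms a real `pm.Model` would supply. Only the output naming (`sigma_log__`, `p_logodds__`) follows PyMC's convention for value variables; the function name `transform_posterior` is made up for illustration.

```python
import numpy as np

# Hypothetical stand-in for the forward transforms a pm.Model would provide.
# Each entry maps a variable name to (transform name, forward function).
FORWARD_TRANSFORMS = {
    "sigma": ("log", np.log),  # (0, inf) -> real line
    "p": ("logodds", lambda x: np.log(x) - np.log1p(-x)),  # (0, 1) -> real line
}


def transform_posterior(posterior, transforms):
    """Map untransformed posterior draws to the transformed (value) space.

    posterior: mapping of variable name -> array of shape (chain, draw, ...)
    transforms: mapping of variable name -> (transform name, forward function)

    Variables without a transform are copied unchanged; transformed ones are
    stored under PyMC-style value-variable names such as "sigma_log__".
    """
    out = {}
    for name, draws in posterior.items():
        draws = np.asarray(draws)
        if name in transforms:
            tname, forward = transforms[name]
            out[f"{name}_{tname}__"] = forward(draws)
        else:
            out[name] = draws
    return out
```

In an actual refactor the transforms would be looked up on the model rather than hard-coded, and the result would stay an `xarray.Dataset` so that chain/draw coordinates are preserved.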
Steps
- Write a function that transforms posterior draws given a `pm.Model` instance and an `xarray.Dataset` (e.g. `idata.posterior`) containing untransformed draws.
- Refactor `EmpiricalGroup` to work with that `xarray.Dataset` instead of a `MultiTrace`.
- Maintain backwards compatibility (with a deprecation warning) by automatically converting `MultiTrace` inputs to `InferenceData` internally.
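The backwards-compatibility step could look roughly like this. It is a hypothetical sketch, not PyMC code: `MultiTraceStub`, `multitrace_to_posterior` and `EmpiricalGroupSketch` are illustrative stand-ins, and in PyMC the conversion would presumably go through `pm.to_inference_data`. It uses `FutureWarning` rather than `DeprecationWarning` because the latter is hidden by default outside `__main__`.

```python
import warnings


class MultiTraceStub:
    """Hypothetical stand-in for a MultiTrace holding draws per variable."""

    def __init__(self, draws_by_var):
        self.draws_by_var = draws_by_var


def multitrace_to_posterior(trace):
    """Hypothetical converter; stands in for a MultiTrace -> InferenceData step."""
    return dict(trace.draws_by_var)


class EmpiricalGroupSketch:
    """Hypothetical sketch of EmpiricalGroup accepting posterior draws."""

    def __init__(self, trace):
        if isinstance(trace, MultiTraceStub):
            # Old input type: still accepted, but warn and convert internally.
            warnings.warn(
                "Passing a MultiTrace is deprecated; pass InferenceData instead.",
                FutureWarning,
                stacklevel=2,
            )
            trace = multitrace_to_posterior(trace)
        self.posterior = trace
```

Keeping the conversion at the constructor boundary means the rest of the class only ever sees the new input type.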