
Make 'RestrictedTransformForConditional' conform to 'torch.Transform' interface

Open Baschdl opened this issue 1 year ago • 2 comments

RestrictedTransformForConditional is currently typed as a torch.Transform while also holding a transform as an attribute: https://github.com/sbi-dev/sbi/blob/bae69949df20f5616e2b4e9245579d079e09a9e8/sbi/utils/conditional_density_utils.py#L386-L409

The conditioning on theta makes the interface incompatible with torch.Transform. For example, the standard inv takes no arguments (in PyTorch it is a property that returns the inverse transform), whereas our inv(theta) takes one. This particular mismatch could be fixed by renaming our method to restricted_inv(theta).
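To make the proposed rename concrete, here is a minimal sketch. It is not the sbi implementation: `RestrictedSketch` and its constructor arguments are hypothetical names chosen for illustration, and the wrapper deliberately does not subclass `Transform`. It only shows that a `restricted_inv(theta)` method can sit next to the wrapped transform's own `inv` property without a name clash.

```python
import torch
from torch.distributions.transforms import ExpTransform, Transform


class RestrictedSketch:
    """Hypothetical wrapper for illustration only (not the sbi class).

    Applies `transform` to the full parameter vector: the free dimensions
    come from `theta`, the remaining ones from the fixed `condition`.
    """

    def __init__(self, transform: Transform, condition: torch.Tensor, dims_to_sample: list[int]):
        self.transform = transform
        self.condition = torch.atleast_2d(condition)
        self.dims_to_sample = dims_to_sample

    def _embed(self, theta: torch.Tensor) -> torch.Tensor:
        # Copy the condition once per batch entry, then overwrite the free dims.
        full = self.condition.repeat(theta.shape[0], 1)
        full[:, self.dims_to_sample] = theta
        return full

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        return self.transform(self._embed(theta))[:, self.dims_to_sample]

    def restricted_inv(self, theta: torch.Tensor) -> torch.Tensor:
        # `self.transform.inv` is a property returning the inverse transform;
        # we call that inverse on the embedded parameters.
        return self.transform.inv(self._embed(theta))[:, self.dims_to_sample]


# Usage: condition on dimension 1, keep dimensions 0 and 2 free.
condition = torch.tensor([1.0, 2.0, 3.0])
restricted = RestrictedSketch(ExpTransform(), condition, dims_to_sample=[0, 2])
theta = torch.tensor([[0.5, 1.5]])
roundtrip = restricted.restricted_inv(restricted(theta))
```

With the method named `restricted_inv`, nothing in the wrapper shadows `inv`, so code that expects `some_transform.inv` to be a `Transform` is unaffected by the argument-taking inverse.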

Baschdl, Mar 20 '24 13:03

The problem is that RestrictedTransformForConditional is a torch Transform but also takes a torch Transform as an argument. This can be refactored, but I also think it is not a priority for the release milestone, as conditional_potential is not used by anything at the moment.

gmoss13, Aug 20 '24 12:08

Thanks for the clarification.

Then this is related to the recent question in #1223, and we will fix it when we fix the general handling of custom potentials.

janfb, Aug 20 '24 13:08