sbi

VIPosterior object can't be saved

Open · WangYuxiang8 opened this issue 2 years ago · 4 comments

Hi team

I tried to save a VI posterior object to a file using `torch.save`, as shown below, but it throws an error.

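```python
import torch
from sbi.inference import simulate_for_sbi

# simulator, prior, inference, x_o, vi_parameters and num_rounds are defined earlier.
posteriors = []
proposal = prior

for i in range(num_rounds):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=500)

    _ = inference.append_simulations(theta, x).train()

    posterior = inference.build_posterior(sample_with="vi", vi_parameters=vi_parameters)
    posterior = posterior.set_default_x(x_o)

    # Train the variational distribution, then use it as the next proposal.
    posterior.train(**vi_parameters)
    proposal = posterior.set_default_x(x_o)

    posteriors.append(posterior)

# This call raises the AttributeError below.
torch.save(posteriors[-1], "vi_posterior.pt")
```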

```
AttributeError: Can't pickle local object 'make_object_deepcopy_compatible.<locals>.__deepcopy__'
```

By the way, rejection and MCMC posteriors can be saved correctly with `torch.save`.

Best wishes.

WangYuxiang8 · Apr 26 '22 02:04

Hey @WangYuxiang8,

thanks for reporting this! Sorry, this is something we forgot to test :/ I will look into fixing it. The function `make_object_deepcopy_compatible` is meant to make objects compatible with the deepcopy protocol, but it seems to cause problems with pickle (and not only in this case).
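The `<locals>` in the traceback is the key detail: pickle serializes functions by reference to their qualified name, so a function defined inside another function cannot be looked up again at load time, and any object whose `__dict__` stores such a function fails to pickle. A minimal sketch reproducing this, where `make_deepcopy_compatible` and `Holder` are hypothetical stand-ins rather than sbi's actual code:

```python
import pickle


def make_deepcopy_compatible(obj):
    """Hypothetical stand-in: attaches a locally defined __deepcopy__ to obj."""

    def __deepcopy__(memo):
        # Placeholder behavior, for illustration only.
        return obj

    # Stored in the instance __dict__, so pickle has to serialize this local function.
    obj.__deepcopy__ = __deepcopy__
    return obj


class Holder:
    def __init__(self):
        self.value = 42


holder = make_deepcopy_compatible(Holder())

try:
    pickle.dumps(holder)
    pickled_ok = True
except (AttributeError, pickle.PicklingError) as err:
    # Depending on the Python version, this is an AttributeError or a PicklingError.
    pickled_ok = False
    print(f"{type(err).__name__}: {err}")
```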

In the meantime, you can do the following to save the posteriors.

```python
import torch

# Clear the attributes that pickle cannot serialize before saving.
posterior._optimizer = None
posterior.__deepcopy__ = None
posterior._q_build_fn = None
posterior._q.__deepcopy__ = None
torch.save(posterior, "test.pkl")
```

This should let you save it, and it shouldn't affect the loaded object, since these attributes are simply initialized again.
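`torch.save` pickles the object (it uses Python's `pickle` module by default), so this workaround amounts to replacing every attribute pickle cannot handle with `None` before saving. A minimal sketch with plain `pickle`, where `Holder` is a hypothetical class standing in for the posterior:

```python
import pickle


class Holder:
    """Hypothetical stand-in for a posterior with one unpicklable attribute."""

    def __init__(self):
        self.value = 42

        def _build():  # local function: pickle cannot serialize it
            return "rebuilt"

        self._build_fn = _build


holder = Holder()
holder._build_fn = None  # the workaround: clear the unpicklable attribute first

restored = pickle.loads(pickle.dumps(holder))
print(restored.value)      # 42
print(restored._build_fn)  # None: has to be set again before it is used
```

The rest of the state survives the round trip, while the cleared attribute comes back as `None`. When loading a full object with `torch.load` on recent PyTorch versions (2.6 and later), `weights_only=False` is likely needed because the default changed to `weights_only=True`; only do this for files you trust.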

Kind regards,
Manuel

Edit: Adding the following method to `VIPosterior` in `vi_posterior.py` makes this more convenient:

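```python
# Add this method to the VIPosterior class in vi_posterior.py.
def __getstate__(self):
    """Make the object compatible with pickle."""
    # Note: this also clears the attributes on the live object.
    self._optimizer = None
    self.__deepcopy__ = None
    self._q_build_fn = None
    self._q.__deepcopy__ = None
    return self.__dict__
```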

However, I think the missing `_q_build_fn` can cause problems in the loaded object (this should be fixed if you set your `q` again).
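One possible refinement, sketched under the assumption that only these attributes need clearing: the `__getstate__` above sets them to `None` on the live object itself, so a posterior that has just been saved also loses its optimizer and `_q_build_fn`. Returning a modified copy of `__dict__` avoids that side effect. `PosteriorLike` is a hypothetical class, not sbi's `VIPosterior`, and because the copy is shallow, a nested object such as `_q` would still need its own handling.

```python
import pickle


class PosteriorLike:
    """Hypothetical stand-in; attribute names mirror the thread, not sbi internals."""

    def __init__(self):
        self.value = 42
        self._optimizer = None  # hypothetical placeholder for an optimizer

        def _build_q():  # local function: not picklable
            return "q"

        self._q_build_fn = _build_q

    def __getstate__(self):
        # Copy the instance dict so that saving does not modify the live object.
        state = self.__dict__.copy()
        for key in ("_optimizer", "_q_build_fn"):
            state[key] = None
        return state

    def __setstate__(self, state):
        # Restore saved attributes; cleared ones must be rebuilt by the caller.
        self.__dict__.update(state)


p = PosteriorLike()
restored = pickle.loads(pickle.dumps(p))
print(callable(p._q_build_fn))  # True: the live object is untouched
print(restored._q_build_fn)     # None
print(restored.value)           # 42
```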

manuelgloeckler · Apr 26 '22 09:04

Hi @manuelgloeckler

Thanks, it works for me!

Best wishes.

WangYuxiang8 · Apr 26 '22 14:04

This seems to be solved. Feel free to reopen if needed.

janfb · Jun 21 '22 09:06

Let's keep this open until the `__getstate__` fix is merged.

michaeldeistler · Jun 21 '22 10:06