Loaded model cannot reproduce the predictions of the saved model

Open kaltinel opened this issue 2 years ago • 8 comments

Hi, I am trying to save and load my model, and I believe there is a problem with how this works. I train my model and test it on a dataset, then retrieve some performance diagnostics (precision, recall, etc.). Then I save my model. Afterwards, I load the model and test it on the exact same test dataset to check whether my saving / loading is correct. I believe it is not, as the diagnostics do not match each other.

For example, the accuracy metrics for the trained model for three epochs are:

Precision       |      Recall   |   Specificity   |      FPR        | F1_Score      |  Accuracy
0.70            |      0.63     |       0.40      |     0.59        |     0.66      |    0.56
0.68            |      0.65     |       0.31      |     0.68        |     0.66      |    0.55
0.68            |      0.48     |       0.49      |     0.50        |     0.56      |    0.48 *

After I save this model, I load the model, guide and optimizer on a different script and try Predictive on the same test data. My results from the loaded model are as shown:

Precision | Recall | Specificity | FPR      | F1_Score  | Accuracy
0.70      | 0.51   | 0.51        | 0.48     | 0.59      | 0.51

I would expect the last line of the trained model's results (marked with *) to be the same as the loaded model's results, as both runs use the same seed and are tested on the same data. And of course, the loaded model doesn't get any further training; Predictive is only used on it once.

I save my model as:

torch.save({"model" : mymodels.state_dict(), "guide" : mymodels.guide}, path_to_model) # save the model and guide params
pyro.get_param_store().save(path_to_parameters) # save parameters
adam.save(path_to_optim) # save the optimizer

And I load it as:

pyro.clear_param_store() ## clear
## load 
saved_model_dict = torch.load(path_to_model) 
mymodels.load_state_dict(saved_model_dict['model']) 
mymodels.guide = saved_model_dict['guide']

pyro.get_param_store().load(path_to_parameters)
adam.load(path_to_optim)
svi = SVI(mymodels.model, config_enumerate(mymodels.guide, "parallel", expand=True), adam, TraceEnum_ELBO(max_plate_nesting=2, strict_enumeration_warning=False))

And then call Predictive:

predictive = Predictive(svi.model, guide=svi.guide, num_samples=args.numsampling)

I cannot understand why I see this behaviour, and think it may be a bug. I would be glad to have your insights. Thank you.

kaltinel avatar Sep 07 '22 15:09 kaltinel

Hi @kaltinel just to clarify, can you confirm that you are calling pyro.set_rng_seed(...my_seed...) before running predictive(...) in each case?

Also, is there a reason you're loading adam and creating an svi instance, rather than directly constructing the predictive?

predictive = Predictive(model, guide, num_samples=args.num_samples)

fritzo avatar Sep 07 '22 18:09 fritzo

Hi, thank you for your answer. Yes, the seed was the first possible culprit I thought of, so I confirm that the pyro.set_rng_seed(#) is the same in each case, and has been run before running Predictive().

As for how I construct the predictive, there is no specific reason for doing it that way. I thought I would save my model, guide, parameters and optimizer, so that I could load everything with all the required elements. Would you suggest that I save like this instead:

torch.save({"model" : mymodels.state_dict(), "guide" : mymodels.guide}, path_to_model) # save the model and guide params
pyro.get_param_store().save(path_to_parameters) # save parameters

And, to load:

saved_model_dict = torch.load(path_to_model) 
mymodels.load_state_dict(saved_model_dict['model']) 
mymodels.guide = saved_model_dict['guide']
pyro.get_param_store().load(path_to_parameters)
predictive = Predictive(mymodels.model, guide=mymodels.guide, num_samples=args.numsampling)

If so, can you explain why? (Is it because the optimizer is not needed for prediction, so there is no need to load it?)

kaltinel avatar Sep 08 '22 08:09 kaltinel

> Would you suggest

Whatever works for you, I was only looking for a minimal reproducible example and would have expected something like

torch.save(model, path_to_model)
torch.save(guide, path_to_guide)
pyro.get_param_store().save(path_to_parameters)

pyro.clear_param_store()

model = torch.load(path_to_model)
guide = torch.load(path_to_guide)
pyro.get_param_store().load(path_to_parameters)

predictive = Predictive(model, guide=guide, num_samples=args.numsampling)

But again, whatever works for you.

Back to your main problem 🙂 could you try cranking up the number of samples? That should help distinguish whether the problem is due to random noise or actually bad parameters. If your results differ with large num_samples, then we could try to diff the param store before and after, e.g.:

def get_snapshot():
    return {k: v.detach().clone() for k, v in pyro.get_param_store().items()}

snapshot1 = get_snapshot()
...save however you like...
pyro.clear_param_store()
del model, guide, svi, predictive  # ensure a fresh environment
...load however you like...
snapshot2 = get_snapshot()

# check identity
assert set(snapshot1) == set(snapshot2), "keys differ"
for k, v1 in snapshot1.items():
    v2 = snapshot2[k]
    assert torch.allclose(v1, v2), f"mismatch at key {repr(k)}"

If that fails, we could try something similar with snapshots of model.named_parameters() and guide.named_parameters().
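The same comparison for module parameters could be sketched roughly as follows. This is a minimal illustration, not code from the thread: the helper names snapshot_named_parameters and diff_snapshots are hypothetical, and a plain nn.Linear stands in for the model or guide, assuming they are torch.nn.Module (or PyroModule) instances.

```python
import torch
from torch import nn


def snapshot_named_parameters(module):
    """Copy every named parameter so later in-place changes are not aliased."""
    return {name: p.detach().clone() for name, p in module.named_parameters()}


def diff_snapshots(before, after):
    """Return a list of human-readable differences between two snapshots."""
    problems = []
    for name in sorted(set(before) | set(after)):
        if name not in before or name not in after:
            problems.append(f"{name}: missing on one side")
        elif before[name].shape != after[name].shape:
            problems.append(f"{name}: shapes differ")
        elif not torch.allclose(before[name], after[name]):
            problems.append(f"{name}: values differ")
    return problems


# Stand-in module (assumption): replace with the real model or guide.
net = nn.Linear(3, 2)
before = snapshot_named_parameters(net)

# Nothing changed yet, so the diff is empty.
unchanged = diff_snapshots(before, snapshot_named_parameters(net))

# Simulate a parameter that changed between saving and loading.
with torch.no_grad():
    net.bias.add_(1.0)
changed = diff_snapshots(before, snapshot_named_parameters(net))

print(unchanged)
print(changed)
```

Taking one snapshot before saving and one after loading, then printing the diff, narrows down which parameter (if any) did not survive the round trip.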

Thanks for diving into the debugging!

fritzo avatar Sep 08 '22 13:09 fritzo

Hi, thank you for your detailed answer. I tried increasing the number of samples and made sure that the test dataset is definitely the same, by making an external variable that parses the test file (apologies for not being able to increase the number of samples for this time point), and the diagnostics are still different from each other.

Then I tried your suggestion of comparing the parameters in the param store, and the snapshots are the same, which I believe is great! However, this left me more confused, as now I really cannot understand why I get different Predictive results from the same dataset, the same model and guide, and the same parameters.

I look forward to your feedback, thank you!

kaltinel avatar Sep 09 '22 15:09 kaltinel

Hmm, I'm not sure what else to try. Is there any way you could make a reproducible example we could look at?

fritzo avatar Sep 09 '22 16:09 fritzo

Hi, I am sorry for the time lapse in between replies. I was trying to solve the issue. (Unfortunately I am unable to share my code due to privacy regulations on the matter.)

However, I may have found the culprit: I saved the model, guide and parameters from the 'testing' script and compared them with the ones loaded from the training script. The guide and parameters match each other; however, the model that is loaded into the testing script (the one generated during training) is not the same as the one that is saved from the testing script.

I am puzzled: how come I can save / load the guide and parameters exactly, but not the model? I can't see how the model changes during testing.

As I mentioned in my first comment, I use mymodels.state_dict() to save and mymodels.load_state_dict(saved_model_dict['model']) to load.

I look forward to your insights on the matter.

Thank you.

kaltinel avatar Sep 27 '22 12:09 kaltinel

What are the differences between your original model and saved-then-loaded model? Do you save any randomly-generated tensors in the model? Have you considered using torch.save() and torch.load() for the whole mymodel (parameters and guide), as in

torch.save(mymodel, "mymodel.pt")
mymodel2 = torch.load("mymodel.pt")
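As an illustration of why randomly generated tensors matter here, the sketch below shows one way a state_dict() round trip can fail to reproduce a module: a tensor stored as a plain attribute is not part of the state dict, whereas a registered buffer is. The Toy class and its attribute names are hypothetical and not taken from the issue, and this is not claimed to be the cause in this thread.

```python
import io

import torch
from torch import nn


class Toy(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        # Plain tensor attribute: NOT included in state_dict().
        self.noise = torch.randn(2)
        # Registered buffer: included in state_dict().
        self.register_buffer("offset", torch.randn(2))


original = Toy()

# Save only the state dict, here to an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)

# A fresh instance draws new random tensors before the state dict is loaded.
restored = Toy()
restored.load_state_dict(torch.load(buffer))

keys = sorted(original.state_dict().keys())
weights_match = torch.equal(original.linear.weight, restored.linear.weight)
offset_match = torch.equal(original.offset, restored.offset)
noise_match = torch.equal(original.noise, restored.noise)

print(keys)
print(weights_match, offset_match, noise_match)
```

Pickling the whole object with torch.save(mymodel, ...) keeps plain attributes as well, which is one reason the two approaches can behave differently.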

fritzo avatar Oct 02 '22 18:10 fritzo

Thank you for your reply.

I did save my model with torch.save(), and in the end the issue turned out to be the seed and the position of the SVI instance.

  • I was instantiating my SVI before loading the pyro parameters, which should not be the case.
  • I was calling Predictive inside a for loop, once per epoch, which advanced the random state between calls and hence changed the output.
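The second point can be illustrated with a small sketch. It uses torch.manual_seed and torch.randn as a stand-in for a Predictive call (an assumption made to keep the example self-contained); pyro.set_rng_seed seeds torch's global random number generator, so the same reasoning applies to Pyro sampling. The function name draw is hypothetical.

```python
import torch


def draw(num_samples):
    """Stand-in for a Predictive call: consumes the global torch RNG."""
    return torch.randn(num_samples)


# Training script: seed set once, sampler called once per epoch.
torch.manual_seed(0)
per_epoch = [draw(4) for _ in range(3)]

# Testing script: same seed, but the sampler is called only once.
torch.manual_seed(0)
fresh = draw(4)

# The first epoch matches the fresh run, the third epoch does not,
# because two earlier calls already advanced the RNG state.
same_as_first = torch.equal(per_epoch[0], fresh)
same_as_last = torch.equal(per_epoch[2], fresh)

# Reseeding immediately before every call makes each call reproducible.
reseeded = []
for _ in range(3):
    torch.manual_seed(0)
    reseeded.append(draw(4))
all_reseeded_match = all(torch.equal(r, fresh) for r in reseeded)

print(same_as_first, same_as_last, all_reseeded_match)
```

Under this reading, the starred third-epoch row and the single call in the loading script would only be comparable if the seed is reset right before each Predictive call in both scripts.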

Maybe you could update your documentation regarding this? The documentation on saving / loading Pyro models is, to my knowledge, not easy to find, and I got this information from colleagues. I would appreciate more detailed documentation. I am sure other users would, too :)

kaltinel avatar Oct 06 '22 15:10 kaltinel