pymc
Fix tuning of CompoundStep with multiple chains in serial
When sampling a model with a `CompoundStep` in serial, the first chain runs with tuning, but for subsequent chains the individual samplers have their `tune` parameter set to `False`.
The problem seems to be in `sampling._iter_sample()` here. The `tune` flag is set, but the `CompoundStep` class does not have that attribute, so it doesn't get propagated to the individual samplers. It does have `stop_tuning`, which gets triggered on the first chain and stops tuning for all subsequent chains.
There is also `reset_tuning`, which in `BaseHMC` sets `tune=True` here. However, `Metropolis` resets all tuning parameters except the value of `tune` itself. A more consistent approach might be to store the original value of `tune` and restore it in `reset_tuning`; then the line in `sampling._iter_sample()` where `tune` is set could be removed.
But I'm not sure how this would work with all the other samplers.
Instead, to keep the changes minimal, I added a `tune` property to `CompoundStep`. The setter sets the `tune` attribute on each individual stepper, if it exists, and the getter checks whether any of the steppers has tuning on.
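A minimal sketch of the `tune` property idea. `ToyStep`, `ToyStepWithoutTune` and `ToyCompoundStep` are hypothetical stand-ins, not PyMC classes, and the `methods` attribute holding the sub-steppers is an assumption made for illustration.

```python
class ToyStep:
    """Hypothetical stand-in for a single step method with a boolean `tune` flag."""

    def __init__(self, tune=True):
        self.tune = tune


class ToyStepWithoutTune:
    """Hypothetical step method that has no `tune` attribute at all."""


class ToyCompoundStep:
    """Hypothetical stand-in for CompoundStep holding several sub-steppers."""

    def __init__(self, methods):
        self.methods = list(methods)

    @property
    def tune(self):
        # Tuning counts as "on" if any sub-stepper that has the flag is tuning.
        return any(getattr(m, "tune", False) for m in self.methods)

    @tune.setter
    def tune(self, value):
        # Propagate the flag only to sub-steppers that actually define `tune`.
        for m in self.methods:
            if hasattr(m, "tune"):
                m.tune = value


step = ToyCompoundStep([ToyStep(), ToyStep(), ToyStepWithoutTune()])
step.tune = False  # simulate tuning being switched off after the first chain
print(step.tune)  # False
step.tune = True  # re-enable tuning before the next chain
print([m.tune for m in step.methods[:2]])  # [True, True]
```

Because the setter checks `hasattr` per sub-stepper, samplers without a `tune` attribute are simply left untouched.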
This was good enough for my purposes, but I'm still not sure whether it solves all issues. In `_iter_sample` here, `step.iter_count` is set to 0 if it exists. That attribute does not exist in `CompoundStep`, but does exist in `BaseHMC`.
There are two ways we can solve this:
- Add `iter_count` as a property to `CompoundStep`, same as I did with `tune`.
- Update `reset_tuning` to set `iter_count = 0`, and possibly remove that line from `_iter_sample`.
I think option 2 would be a cleaner solution. If we go that route, then `tune` should also be restored to its original value when resetting. I think I'll implement that, unless there are other suggestions.
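A sketch of option 2 on a hypothetical Metropolis-like sampler. `ToyMetropolis` is not PyMC's class; the `_initial_tune` and `_initial_scaling` attributes and the multiplicative scaling update are illustrative assumptions. The point is that `reset_tuning` restores the original `tune` value and `iter_count` along with the other tuning state, so each chain in a serial run starts exactly like the first.

```python
class ToyMetropolis:
    """Hypothetical Metropolis-like sampler; not PyMC's actual class."""

    def __init__(self, scaling=1.0, tune=True):
        # Remember the original configuration so it can be restored per chain.
        self._initial_tune = tune
        self._initial_scaling = scaling
        self.tune = tune
        self.scaling = scaling
        self.iter_count = 0

    def step(self):
        self.iter_count += 1
        if self.tune:
            self.scaling *= 1.1  # placeholder for a real tuning rule

    def stop_tuning(self):
        self.tune = False

    def reset_tuning(self):
        # Option 2: restore all tuning state, including `tune` itself and the
        # iteration counter, instead of relying on the caller to set them.
        self.tune = self._initial_tune
        self.scaling = self._initial_scaling
        self.iter_count = 0


# Two chains run in serial; each one starts by resetting the sampler.
sampler = ToyMetropolis()
for chain in range(2):
    sampler.reset_tuning()
    print(chain, sampler.tune, sampler.iter_count)  # prints "0 True 0", then "1 True 0"
    for _ in range(3):
        sampler.step()
    sampler.stop_tuning()
```

A sampler created with `tune=False` keeps tuning off after a reset, which is the "original value" behaviour described above.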
Thank you for opening a PR!
Depending on what your PR does, here are a few things you might want to address in the description:
- [x] what are the (breaking) changes that this PR makes?
- [x] important background, or details about the implementation
- [x] are the changes—especially new features—covered by tests and docstrings?
- [x] linting/style checks have been run
- [x] consider adding/updating relevant example notebooks
- [x] right before it's ready to merge, mention the PR in the RELEASE-NOTES.md
Some samplers have a `tune` attribute but no `reset_tuning` method. That's why `_iter_sample()` cannot rely on calling `reset_tuning()` alone.
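To illustrate why a single `reset_tuning()` call is not enough today, here is a hypothetical helper (`prepare_step_for_chain`, not part of PyMC and not the actual `_iter_sample` code) showing the kind of `hasattr()` dispatch a sampling loop is pushed into when only some samplers implement `reset_tuning`. `n_tune` stands for the number of tuning draws; the two sampler classes are toy stand-ins.

```python
def prepare_step_for_chain(step, n_tune):
    """Hypothetical helper: get `step` ready for a new chain with `n_tune` tuning draws."""
    # Samplers without reset_tuning can only be switched back on through the flag.
    if hasattr(step, "tune"):
        step.tune = n_tune > 0
    # Samplers that implement reset_tuning also get their adapted state cleared.
    if hasattr(step, "reset_tuning"):
        step.reset_tuning()
    # Only some samplers (e.g. HMC-style ones) keep an iteration counter.
    if hasattr(step, "iter_count"):
        step.iter_count = 0


class FlagOnlyStep:
    """Hypothetical sampler with a `tune` flag but no reset_tuning method."""

    def __init__(self):
        self.tune = False  # left switched off by the previous chain


class ResettableStep:
    """Hypothetical sampler whose reset_tuning clears adaptation but not `tune`."""

    def __init__(self):
        self.tune = False
        self.scaling = 5.0  # value adapted during the previous chain

    def reset_tuning(self):
        self.scaling = 1.0


a, b = FlagOnlyStep(), ResettableStep()
prepare_step_for_chain(a, n_tune=500)
prepare_step_for_chain(b, n_tune=500)
print(a.tune, b.tune, b.scaling)  # True True 1.0
```

An object without a `tune` attribute, such as a `CompoundStep` before this PR, falls through the first branch, which is how the flag failed to reach the sub-samplers.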
Note that `DEMetropolis` only accepts `tune` values in `[None, "scaling", "lambda"]`, but `_iter_sample` sets it to a bool. As a result, tuning won't be performed during a step (see here). This has been fixed in `DEMetropolisZ` by introducing a `tune_type` parameter and making `tune` boolean. This fix should be applied to `DEMetropolis` in a separate PR.
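The following hypothetical sketch shows why assigning a boolean to a string-valued `tune` silently disables tuning, and how separating `tune_type` (what to tune) from a boolean `tune` (whether to tune) avoids it. `ToyDEMetropolis` and `ToyDEMetropolisZ` are stand-ins, not PyMC's classes, and the `*= 1.1` updates are placeholders for the real tuning rules.

```python
class ToyDEMetropolis:
    """Hypothetical sampler whose `tune` selects *what* to tune (old-style interface)."""

    def __init__(self, tune="scaling"):
        if tune not in (None, "scaling", "lambda"):
            raise ValueError(f"Unsupported tune value: {tune!r}")
        self.tune = tune
        self.scaling = 1.0
        self.lamb = 1.0

    def step(self):
        # Tuning only happens when `tune` equals one of the expected strings.
        if self.tune == "scaling":
            self.scaling *= 1.1
        elif self.tune == "lambda":
            self.lamb *= 1.1


class ToyDEMetropolisZ:
    """Hypothetical sampler with a split interface: `tune_type` (what) + `tune` (whether)."""

    def __init__(self, tune_type="scaling"):
        if tune_type not in (None, "scaling", "lambda"):
            raise ValueError(f"Unsupported tune_type: {tune_type!r}")
        self.tune_type = tune_type
        self.tune = True  # a plain on/off switch
        self.scaling = 1.0
        self.lamb = 1.0

    def step(self):
        if not self.tune:
            return
        if self.tune_type == "scaling":
            self.scaling *= 1.1
        elif self.tune_type == "lambda":
            self.lamb *= 1.1


old = ToyDEMetropolis(tune="scaling")
old.tune = True  # what a generic sampling loop does with a boolean flag
old.step()
print(old.scaling)  # 1.0: True == "scaling" is False, so tuning is skipped

new = ToyDEMetropolisZ(tune_type="scaling")
new.tune = True
new.step()
print(new.scaling)  # 1.1
```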
The `iter_count` attribute is used by NUTS to affect the maximum tree depth during tuning, and by `base_hmc` in warning messages (here and here).
It can be set to 0 in `reset_tuning` as well. Arguably, what `_iter_sample` really needs is a general `step.reset()` method that brings the sampler back to its original configuration. However, for most samplers `reset()` would simply call `reset_tuning()`; only NUTS would also set `iter_count=0`. It's not worth adding an extra layer of complexity for a single step method.
Funnily enough, `BaseHMC` has a `reset` method, which is called by `reset_tuning`: exactly the opposite of the mechanism I described above. Anyway, that's the natural place to set `iter_count=0`.
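A sketch of an HMC-style sampler in which `reset_tuning` delegates to `reset`, and `reset` clears `iter_count`. `ToyHMC` is hypothetical; the `early_max_treedepth` value and the 200-iteration cutoff are illustrative assumptions, used only to show how a stale `iter_count` would leave the second chain past its early-tuning phase.

```python
class ToyHMC:
    """Hypothetical HMC-style sampler; names and numbers are illustrative only."""

    def __init__(self, max_treedepth=10, early_max_treedepth=8, early_iters=200):
        self.max_treedepth = max_treedepth
        self.early_max_treedepth = early_max_treedepth
        self.early_iters = early_iters  # assumed length of the "early tuning" phase
        self.tune = True
        self.iter_count = 0

    def reset(self):
        # Natural place to clear per-chain state such as the iteration counter.
        self.iter_count = 0

    def reset_tuning(self):
        # reset_tuning delegates to reset and switches tuning back on.
        self.reset()
        self.tune = True

    def stop_tuning(self):
        self.tune = False

    def current_max_treedepth(self):
        # Use a shallower tree early in tuning; this relies on iter_count being reset.
        if self.tune and self.iter_count < self.early_iters:
            return self.early_max_treedepth
        return self.max_treedepth

    def step(self):
        self.iter_count += 1


hmc = ToyHMC()
for _ in range(300):  # first chain: tuning draws only, for brevity
    hmc.step()
print(hmc.current_max_treedepth())  # 10: past the early phase
hmc.reset_tuning()  # start of the second chain
print(hmc.current_max_treedepth())  # 8: early phase again
```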
Hi @msibaev,
Sorry about the delay. I just got the chance to read your descriptions and look at the changes. I believe that in the beginning there was only a boolean `tune` parameter, but the typing of that interface drifted over time.
When I implemented `DEMetropolis`, the boolean setting was no longer enough, because the sampler had to reset some things when switching into or out of the tuning phase. If I remember correctly, that's when I introduced the `reset_tuning` method.
At the `_iter_sample` level we can't deal with all the specifics of the different samplers, so yes, we should unify that into one interface and get rid of the `hasattr()` checks.
Hi @michaelosthege
I don't get much time to look at these things nowadays, sorry for the delay.
Thanks for the context on `reset_tuning()`. I guess there should be more standardization across all samplers in that respect, which would allow `_iter_sample()` to be cleaned up further. But that probably goes beyond the scope of this PR. My main goal was just to fix tuning when multiple chains are run in serial. It's an edge case, but an important one for Windows machines, where multiprocessing can be buggy.
Is there anything else that you think should be done here?
Hi @msibaev, I see you're working on this branch again. Cool! Let us know when it's ready for review.
It would be good to open an issue for the problem this PR was trying to address, as the PR has been stale for a while.
Closing this due to lack of activity and the fact that it targets V3.