pymc
Fix progress bar error when nested `CompoundStep` samplers are assigned
#7721 reports an error when `CompoundStep` samplers are nested. Here's a prettier version of the step tree pymc reports for the example in that issue:
```
CompoundStep
├─ CompoundStep
│  ├─ Metropolis: [a]
│  ├─ Metropolis: [b]
│  └─ Metropolis: [c]
└─ NUTS: [d]
```
So there are 4 steps in total, with one `CompoundStep` on the outside and another nested inside it. At each iteration, we get a flat list of 4 dictionaries, one holding the statistics for each leaf step. Currently, the logic for updating the progress bars assumes that the list of step statistics returned at each iteration matches the list of step samplers one-to-one. It was assumed that there would be at most one `CompoundStep`, so its update could simply zip over its child steps. Here is the display stat update for `CompoundStep`:
```python
for step_stat, update_fn in zip(step_stats, update_fns):
    displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
```
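The problem is that with a nested structure, one of the `update_fns` (the one belonging to the inner `CompoundStep`) runs this loop again. Since `step_stats` does not have the same nested structure, the inner loop receives a single statistics dictionary instead of a list, ends up iterating over that dictionary's keys, and raises the error.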
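To make the failure mode concrete, here is a minimal pure-Python sketch that mimics the loop above without pymc. The helpers `make_leaf_update` and `make_compound_update`, and the `"accept"` statistic, are hypothetical stand-ins rather than pymc's actual internals; only the `zip` loop mirrors the code shown.

```python
from typing import Any, Callable

# Hypothetical stand-ins for pymc's per-step "update displayed stats" functions.
# Each takes (displayed_stats, step_stat, chain_idx) and returns displayed_stats.
UpdateFn = Callable[[dict, Any, int], dict]


def make_leaf_update(name: str) -> UpdateFn:
    """Update function for a single (non-compound) step, e.g. Metropolis or NUTS."""

    def update(displayed_stats: dict, step_stat: Any, chain_idx: int) -> dict:
        # A leaf expects step_stat to be the stats dict for its own step.
        displayed_stats[name] = step_stat["accept"]
        return displayed_stats

    return update


def make_compound_update(update_fns: list[UpdateFn]) -> UpdateFn:
    """Mirrors the CompoundStep loop: pair child update fns with step_stats one-to-one."""

    def update(displayed_stats: dict, step_stats: Any, chain_idx: int) -> dict:
        for step_stat, update_fn in zip(step_stats, update_fns):
            displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
        return displayed_stats

    return update


# Sampler tree from the issue: CompoundStep(CompoundStep(a, b, c), NUTS(d)).
inner = make_compound_update([make_leaf_update(v) for v in "abc"])
outer = make_compound_update([inner, make_leaf_update("d")])

# Sampling yields a *flat* list with one stats dict per leaf step.
step_stats = [{"accept": 0.1}, {"accept": 0.2}, {"accept": 0.3}, {"accept": 0.9}]

try:
    outer({}, step_stats, 0)
except TypeError as err:
    # The inner loop zipped over the keys of {"accept": 0.1}, so the leaf
    # update for "a" received the string "accept" instead of a dict.
    print(f"TypeError: {err}")
```

In this sketch the outer `zip` hands the inner compound update a single dict (the stats for `a`), so the inner `zip` walks its keys and the leaf update fails when indexing a string. The outer `zip` would also pair `NUTS: [d]` with the statistics meant for `Metropolis: [b]`, so the misalignment goes beyond the crash itself.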
Open to suggestions on how to proceed, because the solution isn't obvious. I opened this as a draft PR to address the problem ASAP, since it's been reported 3 times now.
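One possible direction, sketched here only for discussion: flatten the tree of update functions so that the leaf steps line up one-to-one with the flat list of statistics. This assumes the flat stats list is ordered depth-first, matching the tree as printed above (`a`, `b`, `c`, `d`). `flatten_update_fns` and `update_compound` are hypothetical helpers, not existing pymc functions, and the nested-list representation of a `CompoundStep` is a simplification.

```python
from typing import Callable, Union

UpdateFn = Callable[[dict, dict, int], dict]
# Hypothetical representation: a leaf is an update function, a CompoundStep is a list of sub-trees.
StepTree = Union[UpdateFn, list]


def make_leaf_update(name: str) -> UpdateFn:
    """Hypothetical leaf update: records one statistic under the step's name."""

    def update(displayed_stats: dict, step_stat: dict, chain_idx: int) -> dict:
        displayed_stats[name] = step_stat["accept"]
        return displayed_stats

    return update


def flatten_update_fns(tree: StepTree) -> list[UpdateFn]:
    """Depth-first list of leaf update functions (assumed to match the flat stats order)."""
    if callable(tree):
        return [tree]
    flat: list[UpdateFn] = []
    for child in tree:
        flat.extend(flatten_update_fns(child))
    return flat


def update_compound(
    tree: StepTree, displayed_stats: dict, step_stats: list[dict], chain_idx: int
) -> dict:
    """Single flat loop over leaves, so nesting depth no longer matters."""
    leaf_fns = flatten_update_fns(tree)
    # Fail loudly instead of letting zip silently truncate on a mismatch.
    if len(leaf_fns) != len(step_stats):
        raise ValueError(f"expected {len(leaf_fns)} stats dicts, got {len(step_stats)}")
    for step_stat, update_fn in zip(step_stats, leaf_fns):
        displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
    return displayed_stats


# Same tree as the issue: CompoundStep(CompoundStep(a, b, c), NUTS(d)).
tree = [[make_leaf_update(v) for v in "abc"], make_leaf_update("d")]
step_stats = [{"accept": 0.1}, {"accept": 0.2}, {"accept": 0.3}, {"accept": 0.9}]
print(update_compound(tree, {}, step_stats, 0))
# {'a': 0.1, 'b': 0.2, 'c': 0.3, 'd': 0.9}
```

The explicit length check is a design choice: with plain `zip`, a mismatch between the number of leaf steps and the number of stats dicts would otherwise be silently truncated.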
Description
Related Issue
- [ ] Closes #7721 #7724 #
- [ ] Related to https://github.com/pymc-devs/pymc-extras/issues/331
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7730.org.readthedocs.build/en/7730/