Fix progress bar error when nested `CompoundStep` samplers are assigned
#7721 reports an error in the presence of nested CompoundStep. Here's a prettier version of what pymc gives for the example in that issue:
CompoundStep
├─CompoundStep
│ ├─ Metropolis: [a]
│ ├─ Metropolis: [b]
│ └─Metropolis: [c]
└─NUTS: [d]
So there are 4 steps, but there's a compound step on the outside and on the inside. At each step, we get a flat list of 4 dictionaries holding statistics for each step. Currently, the logic for updating the progress bars makes the assumption that the list of step statistics returned at each step matches the list of step samplers. It was assumed that, if there is a compound step, there should only be one, so it can zip over the steps. Here is the display stat update for CompoundStep:
for step_stat, update_fn in zip(step_stats, update_fns):
    displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
The problem is that if there's a nested structure, one of the update_fns will run this loop again. Since step_stats does not have the same nested structure, the inner loop ends up iterating over dictionary keys, which raises the error.
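To make the failure mode concrete, here is a minimal sketch that mimics the zip-based update with toy functions. None of these names (`leaf_update`, `compound_update`, the `"accepted"` key) are PyMC's actual API; they are hypothetical stand-ins that only reproduce the positional-pairing assumption.

```python
# Toy stand-ins (hypothetical, not PyMC's classes) for the display-update functions.
def leaf_update(name):
    def update(displayed, step_stat, chain_idx):
        # A leaf sampler expects a single stats dict for its own step.
        displayed[name] = displayed.get(name, 0) + step_stat["accepted"]
        return displayed
    return update


def compound_update(update_fns):
    def update(displayed, step_stats, chain_idx):
        # Positional pairing: assumes step_stats has the same shape as update_fns.
        for step_stat, update_fn in zip(step_stats, update_fns):
            displayed = update_fn(displayed, step_stat, chain_idx)
        return displayed
    return update


# The sampler always hands back one flat list with one dict per leaf step.
flat_stats = [{"accepted": 1, "tune": True} for _ in range(4)]

# Flat structure: four leaves under one compound step, so the zip lines up.
flat = compound_update([leaf_update(n) for n in "abcd"])
flat_result = flat({}, flat_stats, chain_idx=0)

# Nested structure: the inner compound receives a single dict and zips over its keys.
inner = compound_update([leaf_update(n) for n in "abc"])
outer = compound_update([inner, leaf_update("d")])
try:
    outer({}, flat_stats, chain_idx=0)
    error = None
except TypeError as exc:
    error = exc  # the leaf tried to index a dictionary key (a string)
```

In the nested case the inner loop pairs the keys of the first stats dict with the inner update functions, so the first leaf ends up indexing a string rather than a dict.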
Open to suggestions on how to proceed, because the solution isn't obvious. Opened this as a draft PR to address the problem ASAP, since it's been reported 3 times now.
Description
Related Issue
- [ ] Closes #7721 #7724 #
- [ ] Related to https://github.com/pymc-devs/pymc-extras/issues/331
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7730.org.readthedocs.build/en/7730/
Codecov Report
Attention: Patch coverage is 97.29730% with 1 line in your changes missing coverage. Please review.
Project coverage is 92.82%. Comparing base (af81955) to head (a8af2e8). Report is 1 commit behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc/step_methods/compound.py | 92.30% | 1 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #7730 +/- ##
==========================================
- Coverage 92.82% 92.82% -0.01%
==========================================
Files 107 107
Lines 18324 18322 -2
==========================================
- Hits 17010 17007 -3
- Misses 1314 1315 +1
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pymc/step_methods/hmc/nuts.py | 97.68% <100.00%> (-0.02%) | :arrow_down: |
| pymc/step_methods/metropolis.py | 93.20% <100.00%> (-0.03%) | :arrow_down: |
| pymc/step_methods/slicer.py | 97.32% <100.00%> (ø) | |
| pymc/step_methods/compound.py | 97.47% <92.30%> (-0.42%) | :arrow_down: |
Can we have a smoke test that tries a bunch of step samplers for like 10 tune, and which also acts as a regression test for the issue?
I added a solution to the bug, but I think it might be over-engineered. Basically the problem is that once we're inside the sampling loop, we don't have information about which step generated which stats. I assumed the stats could just be paired up with step samplers positionally, but that assumption is broken when there are nested CompoundSteps -- the nesting structure requires that we know more about which stats came from where.
My solution was to add a step_id to steps, which is created at sample time by a step_id_generator function (just itertools.count() by default). The step_id is stored in a meta_info statistic that is not collected or stored as a real sampling statistic. CompoundSteps then choose which update function to apply based on the step_id.
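A rough sketch of the idea, using toy classes rather than the code in this PR: `ToyStep`, `ToyCompound`, `assign_step_ids`, and `update_display` are hypothetical names, and for brevity the lookup by step_id is done once at the root instead of inside each CompoundStep.

```python
import itertools


class ToyStep:
    """Hypothetical leaf sampler that tags its stats with a step_id."""

    def __init__(self, name):
        self.name = name
        self.step_id = None

    def stats(self):
        # meta_info travels with the stats but is not meant to be stored in the trace.
        return {"accepted": 1, "meta_info": {"step_id": self.step_id}}

    def update(self, displayed, step_stat):
        displayed[self.name] = displayed.get(self.name, 0) + step_stat["accepted"]
        return displayed


class ToyCompound:
    """Hypothetical compound step that may contain other compound steps."""

    def __init__(self, methods):
        self.methods = methods

    def leaves(self):
        for method in self.methods:
            if isinstance(method, ToyCompound):
                yield from method.leaves()
            else:
                yield method


def assign_step_ids(root, step_id_generator=None):
    # Default generator is itertools.count(), as described above.
    ids = step_id_generator if step_id_generator is not None else itertools.count()
    for leaf in root.leaves():
        leaf.step_id = next(ids)


def update_display(root, displayed, flat_stats):
    # Dispatch on step_id instead of on position in the list.
    by_id = {leaf.step_id: leaf for leaf in root.leaves()}
    for step_stat in flat_stats:
        leaf = by_id[step_stat["meta_info"]["step_id"]]
        displayed = leaf.update(displayed, step_stat)
    return displayed


root = ToyCompound([ToyCompound([ToyStep("a"), ToyStep("b"), ToyStep("c")]), ToyStep("d")])
assign_step_ids(root)
flat_stats = [leaf.stats() for leaf in root.leaves()]
displayed = update_display(root, {}, flat_stats)
```

Because each stats dict carries its own id, the order and nesting of the steps no longer matter for the display update.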
Maybe I'm missing a simpler solution? Hoping this one at least starts a convo to get to somewhere better.
I don't understand the original problem well enough to know if the solution is reasonable. It sounds to me like the CompoundStep should aggregate the stats for display, so the outer CompoundStep would only see two entries, from NUTS and the inner CompoundStep?
The update to the displayed statistics is done here. The progress manager has access to the return from iter_sample, e.g. here. This is the current MCMC point (not helpful) and a "stats" dictionary. The stats dictionary is always a flat list of dictionaries with length equal to the number of assigned samplers. At sampling time, the ProgressManager only ever sees a single step (the outer-most CompoundStep, or the joint BlockedStep over everything). The hierarchical relationships between steps (if any) are destroyed before sampling begins here.
Currently, the logic for a CompoundStep is to loop over the steps it contains and apply each step's stats update function.
If the list of steps is "flat", for example there is only one joint sampler (e.g. NUTS) or every variable has its own independent step (e.g. Metropolis with blocking=False), this logic works, because the list of step methods held by the outer-most CompoundStep will be aligned with the flat stats list. See here for where this happens.
If, however, the CompoundStep itself holds another CompoundStep, this looping update will be triggered again, and it will try to iterate over stat_dict as if it were the flat list of stats dicts, and we get an error (because it starts iterating over the dictionary keys).
Where does this flattening of stats happen?
I think it's here, due to the use of .extend
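As a toy illustration of why `.extend` loses the nesting (the functions below are hypothetical and only mimic the shape of the stats collection, not PyMC's code):

```python
def collect_stats_flat(step):
    """Leaf steps are strings here; compound steps are lists of steps."""
    if isinstance(step, str):
        return [{"step": step}]
    stats = []
    for inner in step:
        # extend splices the inner list into the outer one, so nesting is lost.
        stats.extend(collect_stats_flat(inner))
    return stats


def collect_stats_nested(step):
    """Alternative that keeps one entry per direct child."""
    if isinstance(step, str):
        return {"step": step}
    return [collect_stats_nested(inner) for inner in step]


nested_steps = [["a", "b", "c"], "d"]  # inner compound of three, plus one more step
flat = collect_stats_flat(nested_steps)
nested = collect_stats_nested(nested_steps)
```

With the flat version the outer step sees four entries; with the nested version it would see two, one of which is itself a list of three.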
@ricardoV94 poke on this, maybe we can do a mini-hack to figure out a better solution than what is proposed in this PR?
Yeah let's try something to unblock this
Closed in favor of #7776