Fix progress bar error when nested `CompoundStep` samplers are assigned

Open jessegrabowski opened this pull request 8 months ago • 9 comments

#7721 reports an error when CompoundStep samplers are nested. Here's a prettier version of the step assignment pymc prints for the example in that issue:

CompoundStep
├─ CompoundStep
│   ├─ Metropolis: [a]
│   ├─ Metropolis: [b]
│   └─ Metropolis: [c]
└─ NUTS: [d]

So there are 4 samplers, wrapped in two CompoundSteps: one on the outside and one on the inside. At each iteration, we get a flat list of 4 dictionaries, each holding the statistics for one sampler. Currently, the logic for updating the progress bars assumes that the list of step statistics returned at each iteration lines up one-to-one with the list of step samplers. The assumption was that if there is a CompoundStep, there is only one, so it can safely zip over its steps. Here is the display stat update for CompoundStep:

            for step_stat, update_fn in zip(step_stats, update_fns):
                displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)

The problem is that with a nested structure, one of the update_fns (the inner CompoundStep's) runs this loop again. Since step_stats is a flat list rather than a nested one, the inner loop receives a single stats dictionary instead of a list, ends up iterating over that dictionary's keys, and raises the error.
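To make the failure mode concrete, here is a toy reproduction in plain Python. It is a minimal sketch, not pymc's code: `leaf_update` and `compound_update` are hypothetical stand-ins that only mimic the positional `zip` shown above, using the same `(displayed_stats, step_stat, chain_idx)` call shape. A flat tree works; the nested tree hands the inner compound a single dict, and in this toy version the mismatch surfaces as a `TypeError`.

```python
# Hypothetical stand-ins for the progress-bar update functions (not pymc's API).

def leaf_update(name):
    """Update fn for a single (non-compound) sampler."""
    def update(displayed_stats, step_stat, chain_idx):
        # A leaf sampler expects exactly one stats dict for itself.
        displayed_stats[name] = step_stat["accept"]
        return displayed_stats
    return update


def compound_update(update_fns):
    """Update fn that pairs stats with child update fns by position."""
    def update(displayed_stats, step_stats, chain_idx):
        # Assumes step_stats is a list aligned one-to-one with update_fns.
        for step_stat, update_fn in zip(step_stats, update_fns):
            displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
        return displayed_stats
    return update


# The sampler returns a FLAT list: one stats dict per leaf sampler a, b, c, d.
flat_stats = [{"accept": 0.5}, {"accept": 0.6}, {"accept": 0.7}, {"accept": 0.9}]

# Flat tree: every sampler sits directly under one CompoundStep, so zip lines up.
flat_outer = compound_update([leaf_update(n) for n in "abcd"])
print(flat_outer({}, flat_stats, 0))  # {'a': 0.5, 'b': 0.6, 'c': 0.7, 'd': 0.9}

# Nested tree from the issue: Compound(Compound(a, b, c), NUTS(d)).
inner = compound_update([leaf_update("a"), leaf_update("b"), leaf_update("c")])
nested_outer = compound_update([inner, leaf_update("d")])
try:
    nested_outer({}, flat_stats, 0)
except TypeError as err:
    # inner received flat_stats[0] (a dict), so zip() walked its keys and
    # leaf "a" was handed the string "accept" instead of a stats dict.
    print("TypeError:", err)
```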

Open to suggestions on how to proceed, because the solution isn't obvious. I opened this as a draft PR to address the problem ASAP, since it's been reported 3 times now.

Description

Related Issue

  • [ ] Closes #7721 #7724 #
  • [ ] Related to https://github.com/pymc-devs/pymc-extras/issues/331

Checklist

Type of change

  • [ ] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pymc--7730.org.readthedocs.build/en/7730/

jessegrabowski • Mar 20 '25 08:03

Codecov Report

Attention: Patch coverage is 97.29730% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.82%. Comparing base (af81955) to head (a8af2e8). Report is 1 commit behind head on main.

Files with missing lines         Patch %   Lines
pymc/step_methods/compound.py    92.30%    1 Missing :warning:
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7730      +/-   ##
==========================================
- Coverage   92.82%   92.82%   -0.01%     
==========================================
  Files         107      107              
  Lines       18324    18322       -2     
==========================================
- Hits        17010    17007       -3     
- Misses       1314     1315       +1     
Files with missing lines           Coverage Δ
pymc/step_methods/hmc/nuts.py      97.68% <100.00%> (-0.02%) :arrow_down:
pymc/step_methods/metropolis.py    93.20% <100.00%> (-0.03%) :arrow_down:
pymc/step_methods/slicer.py        97.32% <100.00%> (ø)
pymc/step_methods/compound.py      97.47% <92.30%> (-0.42%) :arrow_down:
codecov[bot] • Mar 20 '25 09:03

Can we have a smoke test that tries a bunch of step samplers for like 10 tuning steps, and that also acts as a regression test for the issue?

ricardoV94 • Mar 20 '25 10:03

I added a solution to the bug, but I think it might be over-engineered. Basically the problem is that once we're inside the sampling loop, we don't have information about which step generated which stats. I assumed the stats could just be paired up with step samplers positionally, but that assumption is broken when there are nested CompoundSteps -- the nesting structure requires that we know more about which stats came from where.

My solution was to add a step_id to steps, which is created at sample time by a step_id_generator function (just itertools.count() by default). The step_id is stored in a meta_info statistic that is not collected or stored as a real sampling statistic. CompoundSteps then choose which update function to apply based on the step_id.
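A minimal sketch of the step_id idea, assuming stats dicts carry a `meta_info` entry and that each CompoundStep can look up its descendants' update functions by id. `LeafStep`, `CompoundStepSketch`, `update_fns_by_id` and `assign_ids` are hypothetical names for illustration, not the classes or methods in this PR. Because dispatch is keyed on the id rather than on position, the nesting depth and even the order of the flat stats list stop mattering.

```python
import itertools


class LeafStep:
    """Hypothetical stand-in for a single sampler such as Metropolis or NUTS."""

    def __init__(self, name):
        self.name = name
        self.step_id = None

    def assign_ids(self, step_id_generator):
        # Each leaf sampler draws a unique id at sample time.
        self.step_id = next(step_id_generator)

    def stats(self, accept):
        # meta_info rides along with the stats but would not be stored in the trace.
        return {"accept": accept, "meta_info": {"step_id": self.step_id}}

    def update_fns_by_id(self):
        def update(displayed_stats, step_stat, chain_idx):
            displayed_stats[self.name] = step_stat["accept"]
            return displayed_stats
        return {self.step_id: update}


class CompoundStepSketch:
    """Hypothetical CompoundStep that dispatches on step_id instead of position."""

    def __init__(self, methods):
        self.methods = methods

    def assign_ids(self, step_id_generator):
        for method in self.methods:
            method.assign_ids(step_id_generator)

    def update_fns_by_id(self):
        # Merge the children's lookup tables, however deeply they are nested.
        table = {}
        for method in self.methods:
            table.update(method.update_fns_by_id())
        return table

    def update_stats(self, displayed_stats, step_stats, chain_idx):
        table = self.update_fns_by_id()
        for step_stat in step_stats:
            update_fn = table[step_stat["meta_info"]["step_id"]]
            displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
        return displayed_stats


a, b, c, d = (LeafStep(n) for n in "abcd")
outer = CompoundStepSketch([CompoundStepSketch([a, b, c]), d])
outer.assign_ids(itertools.count())  # ids 0, 1, 2, 3 in depth-first order
flat_stats = [a.stats(0.5), b.stats(0.6), c.stats(0.7), d.stats(0.9)]
print(outer.update_stats({}, flat_stats, chain_idx=0))
```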

Maybe I'm missing a simpler solution? Hoping this one at least starts a convo to get to somewhere better.

jessegrabowski • Mar 23 '25 17:03

I don't quite get the original problem well enough to know if the solution is reasonable. Sounds to me like the CompoundStep should aggregate the stats for display, so the outer CompoundStep would only see 2 entries, from NUTS and the inner CompoundStep?
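One hypothetical way to realize this aggregation without changing the flat stats format: each step reports how many entries of the flat list it owns, and a CompoundStep slices the list into one chunk per child, so the outer step sees exactly 2 entries (a 3-dict chunk for the inner CompoundStep and a 1-dict chunk for NUTS). This is a sketch under the assumption that the flat list is ordered depth-first, matching the step tree; `Leaf`, `Compound` and `n_stats` are illustrative names, not pymc's API.

```python
class Leaf:
    """Hypothetical single sampler: owns exactly one entry of the flat stats list."""

    def __init__(self, name):
        self.name = name

    @property
    def n_stats(self):
        return 1

    def update_stats(self, displayed_stats, step_stats, chain_idx):
        # A leaf always receives a list of length one.
        (stat,) = step_stats
        displayed_stats[self.name] = stat["accept"]
        return displayed_stats


class Compound:
    """Hypothetical CompoundStep that groups the flat stats per child."""

    def __init__(self, methods):
        self.methods = methods

    @property
    def n_stats(self):
        return sum(m.n_stats for m in self.methods)

    def update_stats(self, displayed_stats, step_stats, chain_idx):
        if len(step_stats) != self.n_stats:
            raise ValueError(f"expected {self.n_stats} stats dicts, got {len(step_stats)}")
        start = 0
        for method in self.methods:
            # Hand each child only the slice of the flat list it produced.
            chunk = step_stats[start : start + method.n_stats]
            displayed_stats = method.update_stats(displayed_stats, chunk, chain_idx)
            start += method.n_stats
        return displayed_stats


inner = Compound([Leaf("a"), Leaf("b"), Leaf("c")])
outer = Compound([inner, Leaf("d")])
flat_stats = [{"accept": 0.5}, {"accept": 0.6}, {"accept": 0.7}, {"accept": 0.9}]
print(outer.update_stats({}, flat_stats, chain_idx=0))
```

Unlike an id-based lookup, this relies entirely on the ordering assumption, but it needs no extra metadata in the stats.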

ricardoV94 • Mar 24 '25 09:03

The update to the displayed statistics is done here. The progress manager has access to the values returned by iter_sample (e.g. here): the current MCMC point (not helpful) and a "stats" object. That object is always a flat list of dictionaries, with length equal to the number of assigned samplers. At sampling time, the ProgressManager only ever sees a single step (the outermost CompoundStep, or the joint BlockedStep over everything). The hierarchical relationships between steps (if any) are destroyed before sampling begins here.

Currently, the logic for a CompoundStep is to loop over the steps it contains and apply each one's stats update function.

If the list of steps is "flat" (for example, there is only one joint sampler such as NUTS, or every variable has its own independent step, such as Metropolis with blocked=False), this logic works, because the list of step methods held by the outermost CompoundStep is aligned with the flat stats list. See here for where this happens.

If, however, the CompoundStep itself holds another CompoundStep, this looping update is triggered again, and the inner loop tries to iterate over a single stat_dict as if it were the flat list of dicts stats. Iterating over a dictionary yields its keys, so we get an error.

jessegrabowski • Mar 24 '25 12:03

Where does this flattening of stats happen?

ricardoV94 • Mar 25 '25 07:03

I think it's here, due to the use of .extend
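The linked code isn't reproduced here, but the effect of .extend can be shown with a hypothetical toy, assuming stats are gathered recursively over the step tree: `extend` splices each child's list into the parent's, so the grouping disappears, while keeping one entry per child preserves it. `collect_flat`, `collect_nested` and the tuple-based tree are illustrative only.

```python
def collect_flat(step):
    """Gather stats depth-first; `step` is a sampler name or a tuple of child steps."""
    if isinstance(step, str):
        return [{"sampler": step}]
    stats = []
    for child in step:
        # .extend splices the child's list in place, so the nesting is lost.
        stats.extend(collect_flat(child))
    return stats


def collect_nested(step):
    """Same traversal, but keep one entry per child of each compound step."""
    if isinstance(step, str):
        return {"sampler": step}
    return [collect_nested(child) for child in step]


tree = (("a", "b", "c"), "d")  # Compound(Compound(a, b, c), NUTS(d))
print(collect_flat(tree))
# [{'sampler': 'a'}, {'sampler': 'b'}, {'sampler': 'c'}, {'sampler': 'd'}]
print(collect_nested(tree))
# [[{'sampler': 'a'}, {'sampler': 'b'}, {'sampler': 'c'}], {'sampler': 'd'}]
```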

jessegrabowski • Mar 26 '25 02:03

@ricardoV94 poke on this, maybe we can do a mini-hack to figure out a better solution than what is proposed in this PR?

jessegrabowski • May 02 '25 05:05

Yeah, let's try something to unblock this.

ricardoV94 • May 02 '25 17:05

Closed in favor of #7776

jessegrabowski • Jul 10 '25 07:07