pymc
Fixed UserWarning when converting sample_stats to idata
Description
This PR aims to close #7821. Previously, in some cases, running pm.sample_smc() resulted in a UserWarning being generated:
UserWarning: More chains (5) than draws (1). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
I updated the dict_to_dataset() calls in pymc/backends/arviz.py and pymc/smc/sampling.py. In pymc/backends/arviz.py, dict_to_dataset() is now imported from arviz_base rather than arviz.data.base, and the remaining changes adapt the call to that function. In pymc/smc/sampling.py, I also pass sample_dims=["chain"] to dict_to_dataset(), since the entries of sample_stats_dict are 1-D (indexed by chain only, with no draw dimension).
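For reviewers who want to see why declaring the sample dimensions matters, here is a minimal, standard-library-only sketch of the kind of shape check that produces the warning. It is not the ArviZ implementation: `check_sample_shape` is a hypothetical helper, and the assumption that the check only applies when the sample dimensions are ("chain", "draw") is mine, inferred from the warning text.

```python
import warnings


def check_sample_shape(shape, sample_dims=("chain", "draw")):
    """Hypothetical stand-in for a converter's shape sanity check.

    Assumption: with the default ("chain", "draw") sample dims, the first
    two axes are read as (chains, draws), and having more chains than
    draws is treated as suspicious. Returns False if a warning was issued.
    """
    if tuple(sample_dims) == ("chain", "draw") and len(shape) >= 2:
        n_chains, n_draws = shape[0], shape[1]
        if n_chains > n_draws:
            warnings.warn(
                f"More chains ({n_chains}) than draws ({n_draws}). "
                "Passed array should have shape (chains, draws, *shape)",
                UserWarning,
            )
            return False
    return True


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Per-chain values read as 5 chains x 1 draw: warns.
        check_sample_shape((5, 1))
        # Same values declared as chain-only: no (chains, draws) check applies.
        check_sample_shape((5,), sample_dims=("chain",))
    for w in caught:
        print(w.message)
```

In this sketch, only the first call warns. Declaring the data as chain-only sidesteps the (chains, draws) heuristic, which is the effect this PR relies on when it passes sample_dims=["chain"].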
Related Issue
- [x] Closes #7821
- [ ] Related to #
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [x] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):