Don't use deprecated batched_dot
Description
This PR replaces the deprecated pt.batched_dot call in the KroneckerNormal distribution's logprob with an equivalent reduction built on pt.sum, addressing issue #7878.
Problem
The current implementation calls batched_dot, which PyTensor has deprecated and which emits a deprecation warning. Code that depends on deprecated functions can break once they are removed, and may perform worse than the recommended alternatives.
Solution
Rewrote the KroneckerNormal logprob to compute the batched dot product as an elementwise product reduced with pt.sum over an explicit axis. This gives the same result without relying on deprecated APIs.
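For two rank-2 inputs of shape (batch, n), batched_dot returns one dot product per row, so the same quantity can be written as an elementwise product summed over the last axis. The sketch below checks that equivalence using NumPy as a stand-in for PyTensor (np.sum has the same axis semantics as pt.sum). The helpers batched_dot_reference and batched_dot_via_sum are hypothetical illustrations, not part of PyMC or PyTensor.

```python
import numpy as np


def batched_dot_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: per-row dot products of two (batch, n) arrays via a loop.

    For rank-2 inputs this matches what batched_dot computes.
    """
    return np.array([np.dot(a_row, b_row) for a_row, b_row in zip(a, b)])


def batched_dot_via_sum(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: the same quantity as an elementwise product reduced over the last axis."""
    return np.sum(a * b, axis=-1)


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 6))
b = rng.normal(size=(4, 6))

print(np.allclose(batched_dot_reference(a, b), batched_dot_via_sum(a, b)))  # True

# Small hand-checkable case: 1*5 + 2*6 = 17 and 3*7 + 4*8 = 53.
print(batched_dot_via_sum(np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])))  # [17. 53.]
```

On symbolic PyTensor tensors the same reduction would read pt.sum(x * y, axis=-1). Which axis the PR actually reduces over depends on how arrays are laid out in the KroneckerNormal code, which this sketch does not reproduce.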
Tests
- All existing KroneckerNormal tests pass.
- Running the updated code emits no deprecation warnings.
- Results are numerically equivalent to those of the previous implementation.
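One way to check numerical equivalence is sketched below. It assumes the deprecated call computed the squared norm of each whitened residual row, that is, the quadratic form delta^T K^{-1} delta for each row. The NumPy snippet builds a small Kronecker covariance K = kron(K1, K2), whitens a batch of residuals with its Cholesky factor, and compares three results: a per-row loop, the pt.sum-style reduction, and a direct evaluation with K^{-1}. It also escalates DeprecationWarning to an error so that any deprecated call inside the block fails loudly. The snippet forms the full covariance explicitly instead of exploiting the Kronecker structure, so it is a test harness, not PyMC's implementation. The helper random_spd is hypothetical.

```python
import warnings

import numpy as np


def random_spd(n: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: a random symmetric positive-definite (n, n) matrix."""
    a = rng.normal(size=(n, n))
    return a @ a.T + n * np.eye(n)


rng = np.random.default_rng(42)
K1 = random_spd(2, rng)
K2 = random_spd(3, rng)
K = np.kron(K1, K2)  # full (6, 6) Kronecker covariance, built explicitly for testing
L = np.linalg.cholesky(K)  # lower-triangular factor with K = L @ L.T

delta = rng.normal(size=(5, 6))  # 5 residual vectors (value - mu), one per row

with warnings.catch_warnings():
    # Turn any DeprecationWarning raised in this block into an exception.
    warnings.simplefilter("error", DeprecationWarning)

    # Whitened residuals: row b solves L z_b = delta_b.
    sqrt_quad = np.linalg.solve(L, delta.T).T

    # Old style: one dot product per row (what batched_dot computes for 2-D inputs).
    quad_loop = np.array([row @ row for row in sqrt_quad])
    # New style: elementwise square summed over the last axis.
    quad_sum = np.sum(sqrt_quad * sqrt_quad, axis=-1)
    # Reference: delta_b^T K^{-1} delta_b evaluated directly.
    quad_direct = np.einsum("bi,ij,bj->b", delta, np.linalg.inv(K), delta)

print(np.allclose(quad_loop, quad_sum), np.allclose(quad_sum, quad_direct))  # True True
```

The direct evaluation works because z_b^T z_b = delta_b^T L^{-T} L^{-1} delta_b = delta_b^T K^{-1} delta_b, so all three quantities should agree up to floating-point error.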
Related Issue
Fixes #7878
Checklist
- [x] Follows PyMC contributing guidelines
- [x] All tests pass locally
- [x] No API-breaking changes
📚 Documentation preview 📚: https://pymc--7951.org.readthedocs.build/en/7951/