pymc-experimental
Implement optimized closed-form gradients for Kalman Filter
I'm quite interested in the results of this paper. The authors derive closed-form gradients for backpropagating through Kalman filters, specifically in equations 28-31.
They report a 38x speedup over PyTorch's autodiff gradients. I suspect (with no evidence) that the gradient computations are where the default PyMC sampler really falls down, so this might even make non-JAX sampling of state-space (SS) models palatable.
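To illustrate what "closed-form gradients" means here, below is a minimal sketch. It does **not** reproduce the paper's equations 28-31. It uses a different, simpler technique: hand-derived forward sensitivity recursions for a scalar local-level model. The model, the function name `local_level_loglik_and_grad`, and the initial values `m0`/`P0` are illustrative assumptions. The filter propagates the derivatives of the state mean and variance alongside the usual recursions, so it returns the log-likelihood and its gradient with respect to the noise variances `(q, r)` without any autodiff.

```python
import numpy as np


def local_level_loglik_and_grad(y, q, r, m0=0.0, P0=10.0):
    """Kalman filter log-likelihood of a scalar local-level model and its
    gradient w.r.t. (q, r), via hand-derived derivative recursions.

    Assumed (illustrative) model, not taken from the paper:
        x_t = x_{t-1} + w_t,  w_t ~ N(0, q)
        y_t = x_t + v_t,      v_t ~ N(0, r)
    """
    m, P = m0, P0
    dm = np.zeros(2)   # d(filtered mean) / d(q, r)
    dP = np.zeros(2)   # d(filtered variance) / d(q, r)
    ll = 0.0
    dll = np.zeros(2)  # d(log-likelihood) / d(q, r)
    for yt in y:
        # Predict step: the variance grows by q, so dP_pred/dq picks up a 1
        P_pred = P + q
        dP_pred = dP + np.array([1.0, 0.0])
        # Innovation and its variance: dS/dr picks up a 1
        v = yt - m
        dv = -dm
        S = P_pred + r
        dS = dP_pred + np.array([0.0, 1.0])
        # Gaussian log-likelihood contribution and its exact derivative
        ll += -0.5 * (np.log(2 * np.pi) + np.log(S) + v**2 / S)
        dll += -0.5 * (dS / S + 2 * v * dv / S - v**2 * dS / S**2)
        # Kalman gain and its derivative (quotient rule)
        K = P_pred / S
        dK = (dP_pred * S - P_pred * dS) / S**2
        # Update step, differentiated term by term (uses the old dm)
        m = m + K * v
        dm = dm + dK * v + K * dv
        P = (1 - K) * P_pred
        dP = -dK * P_pred + (1 - K) * dP_pred
    return ll, dll
```

A quick check is to compare the returned gradient against central finite differences of the returned log-likelihood. In PyMC, something like this would presumably be wrapped in a custom PyTensor `Op` whose gradient method returns these values instead of letting PyTensor differentiate the filter graph. Forward sensitivities cost grow with the number of parameters, so for models with many parameters a reverse-mode derivation (which is what a backprop-oriented result like the paper's would target) is likely to scale better.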