dynamax
lgssm.fit_em gives radically different behavior on CPU vs TPU
The test file https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/lgssm_test.py passes on CPU but fails on TPU. On CPU I get:
=========================================================================== PASSES ============================================================================
_____________________________________________________ test_sample_and_fit[LinearGaussianSSM-kwargs0-None] _____________________________________________________
-------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------
[-488.31033 -435.52692 -433.90012]
________________________________________________ test_sample_and_fit[LinearGaussianConjugateSSM-kwargs1-None] _________________________________________________
-------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------
[-471.83078 -376.5532 -367.01227]
On TPU I get:
E assert DeviceArray(False, dtype=bool)
E + where DeviceArray(False, dtype=bool) = monotonically_increasing(DeviceArray([-488.46652, -461.02014, -463.8427 ], dtype=float32))
lgssm_test.py:23: AssertionError
> assert monotonically_increasing(lps)
E assert DeviceArray(False, dtype=bool)
E + where DeviceArray(False, dtype=bool) = monotonically_increasing(DeviceArray([-471.99127, -379.4162 , -419.91782], dtype=float32))
lgssm_test.py:23: AssertionError
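To make the size of the regression concrete, here is a minimal sketch in plain Python that locates the first EM step at which the log probability drops, using only the numbers pasted above. The helper `first_decrease` and its `atol` parameter are hypothetical names introduced for illustration; they are not part of dynamax and do not reproduce the test's own `monotonically_increasing` helper.

```python
def first_decrease(lps, atol=0.0):
    """Return (index, drop) for the first step where lps falls by more than atol.

    Hypothetical helper for illustration only; returns None if no such step exists.
    """
    for i in range(1, len(lps)):
        drop = lps[i - 1] - lps[i]
        if drop > atol:
            return i, drop
    return None


# Log probabilities copied from the pytest output in this issue.
runs = {
    "cpu/LinearGaussianSSM": [-488.31033, -435.52692, -433.90012],
    "cpu/LinearGaussianConjugateSSM": [-471.83078, -376.5532, -367.01227],
    "tpu/LinearGaussianSSM": [-488.46652, -461.02014, -463.8427],
    "tpu/LinearGaussianConjugateSSM": [-471.99127, -379.4162, -419.91782],
}

results = {name: first_decrease(lps) for name, lps in runs.items()}

for name, result in results.items():
    if result is None:
        print(f"{name}: monotonically increasing")
    else:
        step, drop = result
        print(f"{name}: drops by {drop:.3f} at step {step}")
```

Both TPU runs decrease on the last step, by roughly 2.8 and 40.5 respectively. That is much larger than what float32 rounding alone would be expected to produce on values of this magnitude, so loosening the comparison tolerance in the test would probably not be enough to make it pass.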
This notebook also has a check for monotonically increasing log probabilities, and it seems to pass, at least on a CPU.
https://github.com/probml/dynamax/blob/main/docs/notebooks/linear_gaussian_ssm/lgssm_learning.ipynb
See https://github.com/google/jax/issues/13224