
lgssm.fit_em gives radically different behavior on CPU vs TPU

Status: Open · murphyk opened this issue 3 years ago · 2 comments

The test file https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/lgssm_test.py passes on CPU but fails on TPU. On CPU, the printed log probabilities (one per EM iteration) increase monotonically for both models:

=========================================================================== PASSES ============================================================================
_____________________________________________________ test_sample_and_fit[LinearGaussianSSM-kwargs0-None] _____________________________________________________
-------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------
[-488.31033 -435.52692 -433.90012]
________________________________________________ test_sample_and_fit[LinearGaussianConjugateSSM-kwargs1-None] _________________________________________________
-------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------
[-471.83078 -376.5532  -367.01227]

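On TPU, the monotonicity assertion fails for both models because the log probability drops at the last iteration: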

E       assert DeviceArray(False, dtype=bool)
E        +  where DeviceArray(False, dtype=bool) = monotonically_increasing(DeviceArray([-488.46652, -461.02014, -463.8427 ], dtype=float32))

lgssm_test.py:23: AssertionError
>       assert monotonically_increasing(lps)
E       assert DeviceArray(False, dtype=bool)
E        +  where DeviceArray(False, dtype=bool) = monotonically_increasing(DeviceArray([-471.99127, -379.4162 , -419.91782], dtype=float32))

lgssm_test.py:23: AssertionError
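To make the failure concrete, the sketch below applies a monotonicity check to the four traces copied from the logs above. The monotonically_increasing function here is a hypothetical stand-in for the helper called at lgssm_test.py:23 (the real one presumably operates on JAX arrays and may differ in detail), and largest_drop is a hypothetical helper added for illustration. It is plain Python so it runs without JAX.

```python
# Hypothetical stand-in for the `monotonically_increasing` helper used in
# lgssm_test.py; the actual dynamax helper may differ (e.g. use jnp, a tolerance).
def monotonically_increasing(xs, atol=0.0):
    """Return True if every element is >= the previous one, up to `atol`."""
    return all(b >= a - atol for a, b in zip(xs, xs[1:]))


def largest_drop(xs):
    """Hypothetical helper: largest decrease between consecutive elements (0.0 if none)."""
    return max([0.0] + [a - b for a, b in zip(xs, xs[1:])])


# Log probabilities copied verbatim from the pytest output in this issue.
traces = {
    "CPU / LinearGaussianSSM": [-488.31033, -435.52692, -433.90012],
    "CPU / LinearGaussianConjugateSSM": [-471.83078, -376.5532, -367.01227],
    "TPU / LinearGaussianSSM": [-488.46652, -461.02014, -463.8427],
    "TPU / LinearGaussianConjugateSSM": [-471.99127, -379.4162, -419.91782],
}

for name, lps in traces.items():
    print(f"{name}: increasing={monotonically_increasing(lps)}, "
          f"largest drop={largest_drop(lps):.3f}")
```

The TPU drops (about 2.8 and 40.5 in log probability) are large enough that loosening the check with a small atol would not make these runs pass.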

murphyk, Nov 09 '22 20:11

This notebook also has a unit test checking that the EM log probabilities increase monotonically, and it seems to pass, at least on CPU.

https://github.com/probml/dynamax/blob/main/docs/notebooks/linear_gaussian_ssm/lgssm_learning.ipynb
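Both monotonicity tests rest on a standard property of EM: in exact arithmetic, an iteration cannot decrease the marginal log likelihood, or, for MAP-EM with a prior (presumably the case for LinearGaussianConjugateSSM), the log likelihood plus log prior. A sketch of the statement, assuming the test's lps[k] holds this objective after iteration k; the symbols θ, y_{1:T}, z_{1:T} are our notation, not dynamax's:

```latex
% Notation (ours): parameters \theta, observations y_{1:T}, latent states z_{1:T}.
% For plain maximum-likelihood EM take \log p(\theta) = 0.
\begin{aligned}
Q(\theta, \theta^{(k)}) &= \mathbb{E}_{p(z_{1:T} \mid y_{1:T}, \theta^{(k)})}\big[\log p(y_{1:T}, z_{1:T} \mid \theta)\big] \\
\theta^{(k+1)} &= \arg\max_{\theta} \big[\, Q(\theta, \theta^{(k)}) + \log p(\theta) \,\big] \\
\log p(y_{1:T} \mid \theta^{(k+1)}) + \log p(\theta^{(k+1)}) &\ge \log p(y_{1:T} \mid \theta^{(k)}) + \log p(\theta^{(k)})
\end{aligned}
```

Because the inequality holds for the exact algorithm, a decrease like the ones seen on TPU suggests a numerical problem in the computation rather than expected EM behavior.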

murphyk, Nov 09 '22 22:11

See https://github.com/google/jax/issues/13224

murphyk, Nov 14 '22 05:11