Brian Patton
TFP has a JAX-backed impl that's fairly functional at this point (note, some distributions are still coming together). If you wanted to add support for running on JAX as well,...
### Description

The test below, when added to `pallas_test.py`, yields the error:

```
jax/_src/lax/control_flow/loops.py", line 1739, in _while_discharge_rule
    if discharged_consts: raise NotImplementedError(discharged_consts) # changed this line
NotImplementedError: [array([[0.]], dtype=float32)]
```

...
Still some bugs to be hammered out, AFAIK, but gets us the basic tiling.