liesel
Replace Python loops in LieselInterface with native Jax
I keep experiencing relatively long compile times. I suspect the Python loops in `LieselInterface.update_state`, and possibly in `LieselInterface.extract_position`, might be partly responsible for this (see this point in the Jax FAQ). I have not looked into it in detail, but I hope these loops can be replaced by something like `jax.lax.scan`.
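To illustrate the general idea, here is a minimal sketch. It assumes the loop body performs the same computation on equally shaped inputs in every iteration. `step`, `update_with_python_loop`, and `update_with_scan` are hypothetical illustrations, not Liesel's actual code. Under `jax.jit`, a Python `for` loop is unrolled at trace time, so the traced program grows with the number of iterations. `jax.lax.scan` traces the body once.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical per-iteration update (not Liesel code): decay the state, add input.
    new_carry = carry * 0.9 + x
    return new_carry, new_carry  # (carry for next iteration, per-step output)


def update_with_python_loop(init, xs):
    # The Python loop runs at trace time, so the traced program contains
    # one copy of `step` per element of xs.
    carry = init
    outs = []
    for i in range(xs.shape[0]):
        carry, y = step(carry, xs[i])
        outs.append(y)
    return carry, jnp.stack(outs)


def update_with_scan(init, xs):
    # lax.scan traces `step` once and loops over the leading axis of xs.
    return jax.lax.scan(step, init, xs)


xs = jnp.arange(5.0)
init = jnp.zeros(())

c1, ys1 = jax.jit(update_with_python_loop)(init, xs)
c2, ys2 = jax.jit(update_with_scan)(init, xs)

# Size of the traced program: grows with len(xs) for the loop, constant for scan.
n_loop = len(jax.make_jaxpr(update_with_python_loop)(init, xs).jaxpr.eqns)
n_scan = len(jax.make_jaxpr(update_with_scan)(init, xs).jaxpr.eqns)
print(n_loop, n_scan)
```

Both versions compute the same result. One caveat: `lax.scan` requires the carry to keep the same shape and dtype across iterations and slices `xs` along a common leading axis. If the loops in `LieselInterface` iterate over variables with different shapes, they may not map onto `scan` directly without restructuring.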