pytensor
Support new `jax.random.key` classes
Description
Currently, pytensor converts random Generator objects using `jax_typify`. The `random_state` is treated as a plain array of uint32s, which is fine for keys created with the old `jax.random.PRNGKey`. The new typed keys created with `jax.random.key` handle things a bit differently, and JAX plans to eventually deprecate the old `PRNGKey` approach. I'm not 100% sure whether the current approach will be affected by that deprecation, but it might be worth trying to adapt the typify code to use the new keys.
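
For reference, here is a minimal sketch of how the two key styles differ and how they can be converted into each other. It is not taken from the pytensor code base: it assumes a JAX version that ships `jax.random.key`, `jax.random.key_data` and `jax.random.wrap_key_data`, and the default PRNG implementation. `is_typed_key` is a hypothetical helper name used only for illustration.

```python
import jax
import jax.numpy as jnp

# Legacy-style key: a plain uint32 array.
old_key = jax.random.PRNGKey(0)

# New-style typed key: a scalar array with a dedicated key dtype.
new_key = jax.random.key(0)


def is_typed_key(x):
    """Hypothetical helper: True if `x` is a new-style typed key array."""
    return jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


# Extract the raw uint32 data that backs a typed key.
raw = jax.random.key_data(new_key)

# Wrap raw uint32 data (for example a legacy key) into a typed key.
rewrapped = jax.random.wrap_key_data(old_key)

# Both styles are currently accepted by the sampling functions.
sample_old = jax.random.uniform(old_key, (3,))
sample_new = jax.random.uniform(rewrapped, (3,))

print(old_key.dtype, is_typed_key(old_key))
print(new_key.dtype, is_typed_key(new_key))
```

If the typify code keeps storing the state as raw uint32s, one possible approach would be to wrap it with `jax.random.wrap_key_data` right before sampling and unwrap it with `jax.random.key_data` afterwards. Whether that fits the existing dispatch code would need checking.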