pytensor

Support new `jax.random.key` classes

Open · lucianopaz opened this issue 9 months ago · 0 comments

Description

Currently, pytensor converts the random `Generator` objects using `jax_typify`. The random state is treated as a plain array of `uint32` values, which works for the old `jax.random.PRNGKey` representation. The new typed key arrays (created with `jax.random.key`) are handled differently, and JAX plans to eventually deprecate the old `PRNGKey` approach. I'm not 100% sure whether the current approach will be affected by that deprecation, but it might be worth trying to adapt the typify code to use the new keys.
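To illustrate the difference between the two representations, here is a minimal sketch using public JAX APIs (`jax.random.PRNGKey`, `jax.random.key`, `jax.random.key_data`, `jax.random.wrap_key_data`, `jax.dtypes.prng_key`). It assumes the default `threefry2x32` PRNG implementation. The helpers `is_typed_key` and `to_typed_key` are hypothetical and are not part of pytensor or JAX; they only show how raw `uint32` state could be wrapped into a typed key.

```python
import jax
import jax.numpy as jnp

# Legacy style: PRNGKey returns a plain uint32 array (shape (2,) with the
# default threefry2x32 implementation). This matches the "array of uint32s"
# representation described above.
legacy_key = jax.random.PRNGKey(42)

# New style: jax.random.key returns a typed key array with a special key
# dtype and scalar shape.
typed_key = jax.random.key(42)


def is_typed_key(x):
    """Hypothetical helper: True if x is a new-style typed PRNG key array."""
    return jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


def to_typed_key(x):
    """Hypothetical helper: wrap raw uint32 key data into a typed key.

    Typed keys are returned unchanged; raw uint32 arrays are wrapped using
    the default PRNG implementation.
    """
    if is_typed_key(x):
        return x
    return jax.random.wrap_key_data(jnp.asarray(x, dtype=jnp.uint32))


# Both representations carry the same underlying key bits.
assert bool((jax.random.key_data(typed_key) == legacy_key).all())

# Sampling with the wrapped key gives the same draw as the legacy key.
wrapped = to_typed_key(legacy_key)
same_draw = bool(jax.random.normal(wrapped) == jax.random.normal(legacy_key))
```

Because `jax.random.key_data` and `jax.random.wrap_key_data` convert losslessly between the two forms, a typify step could in principle keep producing `uint32` state internally and only wrap it at the boundary; whether that fits pytensor's `jax_typify` design is an open question.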

lucianopaz, Mar 20 '25 21:03