Lux.jl
Taking PRNGs seriously
Currently, we have very rudimentary handling of stochastic layers. The RNG state of a stochastic layer is initialized as:

```julia
randn(rng, 1)  # advance the shared RNG in place
return (rng=replicate(rng), training=true)
```
This makes successive stochastic layers start from different RNG states. We need to look at how the JAX frameworks handle this.
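To make the effect concrete, here is a minimal self-contained sketch of the current scheme (the `replicate` stand-in below assumes the CPU fallback, which just copies the RNG):

```julia
using Random

# stand-in for Lux.replicate; assumes the CPU fallback, which copies the RNG
replicate(rng::AbstractRNG) = copy(rng)

# mimics the initialization above for one stochastic layer
function init_state(rng::AbstractRNG)
    randn(rng, 1)  # advance the shared RNG in place
    return (rng=replicate(rng), training=true)
end

rng = MersenneTwister(0)
st1 = init_state(rng)  # layer 1 captures the RNG after one draw
st2 = init_state(rng)  # layer 2 captures it after two draws
randn(st1.rng) == randn(st2.rng)  # false: the layers use different streams
```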
JAX benefits from a natively splittable, immutable RNG interface. To my knowledge there is no equivalent in the Julia ecosystem, but if there were, you could envision how the splitting process would work.
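For illustration only, a toy sketch of what key splitting could look like in Julia (the `Key` type and `split`/`to_rng` names are hypothetical, and hashing is only a stand-in for a proper counter-based generator):

```julia
using Random

# hypothetical immutable key type, in the spirit of jax.random keys
struct Key
    seed::UInt64
end

# derive n child keys deterministically, without mutating the parent
split(k::Key, n::Integer=2) = [Key(hash((k.seed, i))) for i in 1:n]

# materialize an ordinary Julia RNG from a key when numbers are needed
to_rng(k::Key) = Xoshiro(k.seed)

k = Key(UInt64(42))
k1, k2 = split(k)
randn(to_rng(k1)), randn(to_rng(k2))  # reproducible, independent streams
```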
I found something similar that exists in Julia. We would have to shift to Random123.jl and use an approach like the one in DiffEqNoiseProcess.jl's virtual Brownian tree interface: https://github.com/SciML/DiffEqNoiseProcess.jl/blob/c48cdce099cece1edbd8f99da960bc67e3c2c4ca/src/noise_interfaces/virtual_brownian_tree_interface.jl#L139-L148
https://github.com/UBC-Stat-ML/SplittableRandoms.jl/
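If I'm reading its README right, SplittableRandoms.jl already provides roughly this interface (exact API subject to change):

```julia
using SplittableRandoms: SplittableRandom, split

master = SplittableRandom(1)
child = split(master)    # independent child stream derived from master
rand(master), rand(child)
```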