numpyro
Fixes `random_flax_module` with `flax.linen.BatchNorm`
Fixes https://github.com/pyro-ppl/numpyro/issues/1446
The first two commits are from an old branch 🤦. We can squash and merge instead.
@fehiepsi, should I leave `_substitute_default_key` in `utils.py`, or is `handlers.py` a better place?
Leaving it in `utils.py` sounds reasonable to me. It is just a workaround for this edge case.
Thanks for fixing the issue!