IMNN
Updating to JAX v0.4.x
The current upstream version of IMNN is no longer compatible with 2023 releases of JAX, both because it relied on some internal APIs and because of the switch to jax.Array and the removal of the jax.ops.... operators.
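For anyone hitting the same errors, here is a minimal sketch of the kind of change involved. The array and indices are made up for illustration and are not taken from the IMNN code; `jax.ops.index_update` and `jax.ops.index_add` are given only as examples of the removed operators.

```python
import jax
import jax.numpy as jnp

# Illustrative array, not taken from the IMNN code base.
x = jnp.zeros(5)

# Old style, which no longer exists in recent JAX releases:
#   x = jax.ops.index_update(x, jax.ops.index[1:3], 1.0)
#   x = jax.ops.index_add(x, 0, 2.0)

# Equivalent functional updates with the .at[] syntax.
x = x.at[1:3].set(1.0)   # returns a new array with elements 1 and 2 set to 1.0
x = x.at[0].add(2.0)     # returns a new array with 2.0 added to element 0

# Type checks and annotations can target the unified jax.Array type.
is_jax_array = isinstance(x, jax.Array)
```

Both `.at[...]` calls return new arrays rather than mutating in place, so the result has to be reassigned just as with the old `jax.ops` functions.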
I started fixing a few things, and with this branch the first notebook imnn_vs_regression.ipynb works again.
I'm currently a bit stumped by the 2d field inference notebook, mostly I think because I don't have a ton of experience with stax.
I'd also recommend moving away from jax.example_libraries.optimizers and jax.example_libraries.stax and using the more common flax and optax libraries instead. But that's a slightly different question...
For now I'm only opening this PR as a draft to keep track of this; if people want to help, they are most welcome :-)
Pinging some of the usual suspects @tlmakinen @lavaux
@EiffL @tomcharnock I've got an updated version on my end in a Bitbucket branch -- been meaning to update the default for a while -- do you have a use case in mind?
Awesome :-) I thought that might have been the case. Would you be up for merging any additional updates with this branch, so that we can get this repo up to date with upstream JAX? (Assuming this is still the main reference for this code...)
And the concrete use case I have in mind is in the context of a project we are doing with @Justinezgh and @dlanzieri (based on this code https://github.com/DifferentiableUniverseInitiative/sbi_lens) to compare various approaches to training a summary statistic from data. Right now I'm just curious to see whether it would be possible to train our baseline ResNet 18 with an IMNN or if it would blow up memory-wise. Ideally I'd like to be able to use a Flax or Haiku network (to make sure I use exactly the same neural architectures under different training strategies).
But that's separate from the question of helping to keep this code functional :-)