Emile Mathieu
> [This](https://github.com/emilemathieu/escnn_jax/blob/43bd9930a4287a22b35ac2bf3cf382b239ef5836/escnn_jax/nn/modules/equivariant_module.py#L101) should probably be an iteration-over-pytree, not just layers only? Also note that you're technically doing O(n^2) work by calling `tree_inference` at multiple tree depths.

Regarding this, I completely agree,...
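The O(n^2) remark can be illustrated with a pure-JAX sketch. A single `jax.tree_util.tree_map_with_path` call at the root already visits every leaf at every depth, so there is no need to re-run a tree-wide helper on each submodule. The "model" below is a hypothetical nested dict in which some layers carry an `inference` flag; it stands in for a real module tree and is not escnn_jax's actual structure, and `set_inference` is a hypothetical helper.

```python
from jax.tree_util import DictKey, tree_map_with_path


def set_inference(tree, value):
    """Return a copy of `tree` where every leaf stored under the key "inference" is `value`.

    One traversal from the root reaches all depths, so the work is linear in the
    size of the tree. Re-running a whole-subtree traversal at every depth instead
    revisits deep nodes once per ancestor (quadratic in the worst case).
    """

    def _set(path, leaf):
        # `path` is the tuple of keys leading from the root to this leaf
        if path and isinstance(path[-1], DictKey) and path[-1].key == "inference":
            return value
        return leaf

    return tree_map_with_path(_set, tree)


# Hypothetical nested "model": dicts and lists of layers, some with an inference flag
model = {
    "block1": {"dropout": {"inference": False, "rate": 0.1}, "linear": {"w": 1.0}},
    "block2": [{"inference": False}, {"inference": False}],
}
eval_model = set_inference(model, True)
```

Other leaves (such as `rate` and `w`) are returned unchanged, and the input tree is not modified.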
@patrick-kidger, would you by any chance know whether it's usually better/faster in JAX, when filling in an array, to (1) create an empty array, iterate and fill values...
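A minimal JAX sketch of option (1), which preallocates with `jnp.zeros` and fills entries inside `jax.lax.fori_loop` via `.at[i].set(...)`, shown next to a vectorized alternative using `jax.vmap`. The question above is cut off, so the vectorized variant is only an assumption about what the comparison might be, and `f` is a hypothetical per-entry function.

```python
import jax
import jax.numpy as jnp

N = 8  # number of entries to fill (a Python int, so the loop bound is static)


def f(i):
    # Hypothetical per-entry computation; any scalar function of the index works here
    return jnp.sin(i) * i


@jax.jit
def fill_with_loop():
    # Option (1): preallocate, then write one entry per iteration
    out = jnp.zeros(N)

    def body(i, out):
        # Functional update: returns an array with entry i set to f(i)
        return out.at[i].set(f(i))

    return jax.lax.fori_loop(0, N, body, out)


@jax.jit
def fill_vectorized():
    # Alternative: compute all entries at once by mapping f over the indices
    return jax.vmap(f)(jnp.arange(N))


loop_result = fill_with_loop()
vec_result = fill_vectorized()
```

Under `jax.jit`, XLA can often perform `.at[].set` updates in place, so option (1) does not necessarily copy the array at every step, but the loop still runs sequentially. When the entries are independent of each other, the vectorized form is usually the faster one; timing both on the target hardware (with `.block_until_ready()`) is the safest way to decide.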
Also, to handle both stateful and stateless modules, I found myself adding something like:
```python
for layer in self.layers:
    if "state" in inspect.signature(layer).parameters:
        x, state = layer(x, state)
    else:
        x...
```
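A self-contained version of that dispatch, assuming the truncated `else` branch simply calls `x = layer(x)`. The `Scale` and `CallCounter` classes and the `apply_layers` function are hypothetical stand-ins for stateless layers, stateful layers, and the module's forward pass.

```python
import inspect


class Scale:
    """Hypothetical stateless layer: multiplies its input by a constant."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


class CallCounter:
    """Hypothetical stateful layer: passes x through and counts calls in `state`."""

    def __call__(self, x, state):
        return x, state + 1


def apply_layers(layers, x, state):
    for layer in layers:
        # inspect.signature on a callable instance describes its __call__ (without self)
        if "state" in inspect.signature(layer).parameters:
            x, state = layer(x, state)
        else:
            # Assumed stateless branch: the layer only maps x -> x
            x = layer(x)
    return x, state


x, state = apply_layers([Scale(2.0), CallCounter(), Scale(3.0), CallCounter()], 1.0, 0)
```

Because `inspect.signature` runs in plain Python, under `jax.jit` this check happens only while tracing and adds no cost to the compiled function.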
@Gabri95?
Thanks @Gabri95! I've started, and I got `escnn.nn.Linear` to work (i.e., the tests in `test/test_linear.py` pass) by building on the `equinox` module. Now I'm trying to get `escnn.nn.Conv2d` to work, and then to be...
Hey @Gabri95, [I gave it a try](https://github.com/emilemathieu/escnn_jax), and I can now reproduce the `C8SteerableCNN` on MNIST with a ~20% speedup! There are still quite a few things to improve, and modules...
@Gabri95, I'm thinking of making this pip-installable as `escnn_jax`, akin to `e3nn_jax`. Do you have any opinion on this?
Happy it eventually works, ahah! It was more work than expected, and there are still quite a few layers not supported, but it should be easy to add them by taking inspiration from the ones...
@yebai: I ran the code from the issue you opened (https://github.com/JuliaLang/julia/issues/19450), and I do not get any segfault. However, it seems that this issue is still there: https://github.com/JuliaLang/julia/issues/10441. I can't...
The solution posted here, https://discourse.julialang.org/t/how-do-i-deal-with-random-number-generation-when-multithreading/5636/2, seems to be a good way to deal with random numbers when using `@threads`.
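A Python sketch of one common way to make random numbers safe under multithreading: give each task its own independently seeded generator instead of sharing a global one. It uses NumPy's `SeedSequence.spawn` and is only an analogue of the general idea, not a transcription of the linked Julia answer; `sample_means` is a hypothetical helper.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def sample_means(n_tasks, n_samples, seed=0):
    # Derive statistically independent child seeds from one root seed
    child_seeds = np.random.SeedSequence(seed).spawn(n_tasks)
    # One generator per task, indexed by task number rather than by thread identity
    rngs = [np.random.default_rng(s) for s in child_seeds]

    def work(i):
        # Each task touches only its own generator, so no RNG state is shared between threads
        return float(rngs[i].standard_normal(n_samples).mean())

    with ThreadPoolExecutor(max_workers=n_tasks) as pool:
        return list(pool.map(work, range(n_tasks)))


means = sample_means(n_tasks=4, n_samples=1000, seed=42)
```

Since each result depends only on the task index and the root seed, the output is reproducible no matter how the tasks are scheduled onto threads.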