Vincent Roulet
Thank you @noam-delfina for the bug report! For the record, here is a non-trivial example that fails: ``` import jaxopt import jax.numpy as jnp A = jnp.array([[1., 0.]]) b =...
Hello @kclauw,
1. What error do you get exactly?
2. Why do you think the issue is with adamw? adamw does not modify the learning rate internally. Have you...
Hello @kclauw, sorry for the delayed answer.
1. It would help if you could make the example minimal while still reproducing the same error (some dependencies are not defined in what...
Thanks @carlosgmartin, could you add tests? Also you will need to wait for #916 to pass.
Hello @carlosgmartin, Yes, that would be great. Thanks for catching this!
Hello @JadM133, great point. Could you make a minimal working example (MWE) that pinpoints the error we would get? I'm wondering whether we could simply make our own definition of...
Hello @JadM133, sorry for the delay; it's been a busy week. Thank you for the detailed example! It really helps. So the issue is that in equinox the model is...
Hello @JadM133, sorry for the very long delay. Yes, you are absolutely right; I got confused. So the solution of using
```
def mask_callable(x):
    return all(jtu.tree_leaves(jtu.tree_map(lambda e: callable(e), x)))
```
...
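To make the quoted `mask_callable` check concrete, here is a minimal sketch of what it returns, assuming `jtu` is `jax.tree_util`. The example trees (`params`, `fns`) and the helper `double` are hypothetical. The check maps `callable` over every leaf and returns `True` only if every leaf is callable.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def mask_callable(x):
    # Map `callable` over every leaf of the pytree, then require all results to be True.
    return all(jtu.tree_leaves(jtu.tree_map(lambda e: callable(e), x)))


def double(v):
    # Hypothetical plain function used as a leaf.
    return 2 * v


# Hypothetical parameter pytree: array leaves are not callable.
params = {"w": jnp.ones(3), "b": jnp.zeros(1)}

# Hypothetical pytree whose leaves are all functions.
fns = {"w": double, "b": lambda v: v}

print(mask_callable(params))  # False
print(mask_callable(fns))     # True
# A bare function is a single leaf, so it also counts as callable.
print(mask_callable(double))  # True
```

One edge case worth keeping in mind: an empty pytree such as `{}` has no leaves, so `all([])` makes the check return `True`.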
Fixed by #1015. Thanks again @JadM133! The fix was neat. I also added tests that mimic Equinox behavior (callable pytrees).
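A "callable pytree" is an object that is both a JAX pytree and callable, as an Equinox module is. The sketch below uses a hypothetical `CallableTree` class. It is not Equinox itself, and not the actual tests added in #1015. It shows that a boolean mask built with the same structure as such a model is itself callable. So if the masking code distinguished mask functions from mask pytrees with a plain `callable(...)` check (an assumption here), it would misclassify the mask.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class CallableTree:
    """Hypothetical stand-in for an Equinox module: a pytree that also defines __call__."""

    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return self.w * x

    def tree_flatten(self):
        # One child (the weight), no static auxiliary data.
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = CallableTree(jnp.ones(3))

# A boolean mask with the same structure as the model; it is still a CallableTree.
mask = jax.tree_util.tree_map(lambda _: True, model)

print(type(mask).__name__)              # CallableTree
print(jax.tree_util.tree_leaves(mask))  # [True]
print(callable(mask))                   # True, even though it only holds booleans
```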
Yep, we know. This bug does not show up internally or on a Mac, only on Linux (which I don't have). It's related to the new jax release. We are...