Neil Girdhar
> setting acceleration=False to get rid of the FISTA step

I've already set it to false.

> taking a constant stepsize or the schedule of your choice (or maybe just...
I guess one thing that would definitely cause this is if the minimization function `f` returns NaN and that causes JaxOpt's gradient descent to loop forever. I would consider that a...
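A loop can spin forever on NaN if its only exit is a convergence test written as something like `while not (error <= tol)`, because every comparison with NaN evaluates to False. The sketch below is a hand-written plain gradient descent in JAX. It is *not* JaxOpt's implementation. It uses a constant step size with no acceleration, caps the loop with `maxiter`, and exits as soon as the objective or gradient stops being finite. The names `gradient_descent`, `stepsize`, `maxiter`, and `tol` are hypothetical, chosen for illustration only.

```python
import jax
import jax.numpy as jnp


def gradient_descent(f, x0, stepsize=0.1, maxiter=100, tol=1e-5):
    """Hypothetical sketch: constant-step gradient descent, no acceleration.

    Stops when the gradient norm falls to `tol`, when `maxiter` steps have
    been taken, or as soon as the objective or gradient is non-finite (NaN/inf).
    Returns (final x, number of steps taken, final objective value).
    """
    value_and_grad = jax.value_and_grad(f)

    def cond(state):
        i, x, value, g = state
        grad_norm = jnp.linalg.norm(g)
        # Explicit finiteness check: a NaN objective ends the loop instead of
        # depending on how a NaN happens to compare against `tol`.
        finite = jnp.isfinite(value) & jnp.isfinite(grad_norm)
        return (i < maxiter) & (grad_norm > tol) & finite

    def body(state):
        i, x, _, g = state
        x_new = x - stepsize * g  # plain step: no momentum / FISTA term
        value_new, g_new = value_and_grad(x_new)
        return (i + 1, x_new, value_new, g_new)

    x0 = jnp.asarray(x0, dtype=jnp.float32)
    value0, g0 = value_and_grad(x0)
    i, x, value, _ = jax.lax.while_loop(cond, body, (jnp.int32(0), x0, value0, g0))
    return x, i, value
```

For example, `gradient_descent(lambda x: jnp.sum(jnp.log(x)), jnp.array([-1.0]))` returns after zero steps with a NaN value instead of hanging. The `maxiter` cap alone would also bound the loop. The finiteness check just makes the NaN exit immediate and explicit.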
@zaccharieramzi Okay, thanks so much for your kind help!
@mblondel Would you happen to have some time to look at this blocking bug for me?
@mblondel [It is in a single script](https://github.com/NeilGirdhar/jax_freeze_bug/blob/main/jfb/one.py). (I linked it on the issue.) However, if you do the Poetry installation, you're guaranteed to have the same environment too.
@mblondel Oh, I'm sorry, `cli` is the entry point. You can simply append `cli()` to the file to run it.
@mblondel [Yes, that works](https://github.com/google/jax/issues/13864#issuecomment-1396360524). I appreciate your taking a look. I don't think running in 64-bit precision is a reasonable workaround since I don't think Jax should ever lock up...
> The fact that it works in float64 suggests that there might be a way to work around the stability issue by writing the objective function differently. How does...
@mblondel Sorry to keep bugging you on this. I really want to use JaxOpt, but I'm blocked on this bug. Do you think this is likely to be a problem...
@mblondel Okay thanks, I'll start putting that together for you.