Filipp Fisin
By patching Orbax, I've actually implemented JIT-less checkpoint saving/writing in a distributed CPU-backend setup. Patching `sync_global_devices` and `broadcast_one_to_all` was enough.
> Ah ok nice, so this is solved for you?

More or less, yes.

> We might look into a way of doing this without patching as part of our...
> so are you just wanting COMMIT_SUCCESS.txt file in non-GCS contexts

Yes, exactly: writing straight to the final directory and marking the write as "finished" with a `COMMIT_SUCCESS.txt` file.

> Having two different...
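A minimal sketch of the write-in-place-then-mark pattern, assuming a plain local or shared filesystem. The function names and the JSON shard layout are hypothetical, not Orbax's API; only the `COMMIT_SUCCESS.txt` marker name comes from the discussion above.

```python
import json
import os
import tempfile
from pathlib import Path

COMMIT_MARKER = "COMMIT_SUCCESS.txt"


def _write_durably(path: Path, text: str) -> None:
    # Flush and fsync so the data is on disk before the marker is written.
    with open(path, "w") as f:
        f.write(text)
        f.flush()
        os.fsync(f.fileno())


def write_checkpoint(directory: Path, shards: dict) -> None:
    # Write every shard directly into the final directory (no temp dir + rename)...
    directory.mkdir(parents=True, exist_ok=True)
    for name, payload in shards.items():
        _write_durably(directory / name, json.dumps(payload))
    # ...and only then create the marker that flags the checkpoint as finished.
    _write_durably(directory / COMMIT_MARKER, "")


def is_finished(directory: Path) -> bool:
    # Readers treat a directory without the marker as incomplete.
    return (directory / COMMIT_MARKER).is_file()


root = Path(tempfile.mkdtemp())
ckpt = root / "step_100"
print(is_finished(ckpt))  # False
write_checkpoint(ckpt, {"params.json": {"w": [1.0, 2.0]}})
print(is_finished(ckpt))  # True
```

Because the marker is written last, a crash mid-write leaves a directory without `COMMIT_SUCCESS.txt`, which readers can skip or clean up.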
Ok, thank you!
Yeah, I don't see how this KEP would help. Actually, after a quick reading of this KEP, I think it would introduce the same problem I'm experiencing with JobSet in...
Totally agree with you. I'm currently using a hand-crafted Argo workflow to launch multi-node training, which also requires force-deleting pods stuck in the Terminating state; this just deletes them from etcd...
Hi guys, I'm also wondering how I can approach this. Any news? @cgarciae @gabbard
I've just upgraded to the latest Flax (0.12.0), and it seems that a function wrapped with `nnx.jit` now exposes `.lower`, so I guess this issue can be marked as resolved.
I believe it would impose a specific ordering on calls to plugins' `predicateFn`, which might be unwanted behavior: if we allow it for one plugin, then we have to allow...
Here is a toy repro, tested on JAX 0.4.34:
```
import flax.linen as nn
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
import numpy as np
from flax.linen.linear import default_kernel_init
...
```