Kevin Black

Results 11 comments of Kevin Black

+1 to this issue, as it would make the benchmarks much easier to run on setups compatible with the new `tf.distribute.Strategy` API. For now, I was able to hack it...

Update: I implemented FSDP, and now it happens with the train step too, which is a bit more of an issue. Still works on a single VM but not the...

Thanks @skye! I did try with jax/jaxlib 0.4.20 at some point and the same thing happened, as far as I could tell. In collecting the HLO dump, I discovered that...
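For anyone trying to reproduce this: one common way to collect an HLO dump from JAX is XLA's `--xla_dump_to` flag, passed through the `XLA_FLAGS` environment variable before JAX initializes its backend. This is a minimal sketch. The helper name `enable_hlo_dump` and the `/tmp/hlo_dump` directory are hypothetical choices, and the sketch assumes that any flags already in `XLA_FLAGS` should be kept.

```python
import os

# Hypothetical dump directory; any writable path should work.
DUMP_DIR = "/tmp/hlo_dump"


def enable_hlo_dump(dump_dir: str) -> str:
    """Add --xla_dump_to=<dump_dir> to XLA_FLAGS, keeping existing flags.

    XLA reads XLA_FLAGS when JAX initializes its backend, so call this
    before that happens (safest: before `import jax`).
    """
    flag = f"--xla_dump_to={dump_dir}"
    existing = os.environ.get("XLA_FLAGS", "")
    # Only append the flag once, even if this is called repeatedly.
    if flag not in existing.split():
        os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return os.environ["XLA_FLAGS"]


enable_hlo_dump(DUMP_DIR)
# import jax  # then run the step that misbehaves; compiled HLO is written to DUMP_DIR
```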

As I mentioned at the end, this happens even when calling `metadata()`, so I can't get the tree structure of the checkpoint to specify a sharding. It would also be...

@cpgaffney1 I can't re-open the issue, but this seems like a concrete bug that is separate from #646.

Thanks. I was able to work around it with this:

```python
if os.path.exists(f"{local_path}/{step}/_sharding"):
    os.remove(f"{local_path}/{step}/_sharding")
manager = ocp.CheckpointManager(local_path, ocp.PyTreeCheckpointer())
structure = manager.item_metadata(step)
params = manager.restore(
    step,
    restore_kwargs={
        "restore_args": jax.tree_map(
            lambda _:...
```
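The first step of the workaround above deletes the checkpoint's `_sharding` file before restoring. That step can be isolated into a small helper. This is only a sketch: the name `remove_sharding_file` is hypothetical, and the helper assumes the `{local_path}/{step}/_sharding` layout used in the snippet. It builds the path with `os.path.join` and reports whether a file was actually removed.

```python
import os


def remove_sharding_file(local_path: str, step: int) -> bool:
    """Delete `<local_path>/<step>/_sharding` if it exists.

    Hypothetical helper mirroring the first two lines of the workaround.
    Returns True if a file was removed, False if none was present.
    """
    sharding_path = os.path.join(local_path, str(step), "_sharding")
    if os.path.exists(sharding_path):
        os.remove(sharding_path)
        return True
    return False
```

Call the helper before constructing the `CheckpointManager` and calling `item_metadata(step)` to keep the order used in the workaround. Because it permanently modifies the checkpoint directory, it may be safer to run it on a copy.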

Ah I see, I wasn't aware of StandardCheckpointer. I don't think it's covered in the documentation. What are the concrete differences between StandardCheckpointer and PyTreeCheckpointer?

> `XLA_PYTHON_CLIENT_PREALLOCATE=false uv run scripts/serve_policy.py`

Did this solve the issue? FYI, the original post doesn't have any errors, only INFO logs.
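For context: by default, JAX preallocates a large fraction of GPU memory when its GPU backend starts. Setting `XLA_PYTHON_CLIENT_PREALLOCATE=false` disables this, so memory is allocated on demand instead. Prefixing the variable on the command line, as in the quoted command, should have the same effect as setting it inside the script before JAX is imported. The following is a minimal sketch of the in-script form.

```python
import os

# Must be set before JAX initializes its GPU backend; setting it before
# `import jax` is the safest place. Environment values are strings.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax  # imported afterwards, so the setting takes effect
```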

Here you go; it reproduces on a 4090. You do need a couple of things to trigger the issue: namely, large enough arrays and a fake "train step".

```python
import...
```

Sure thing. I just tested it with 8 GPUs:

- Parallel: ~100s to build each pipeline, 108s total
- Sequential: ~50s to build each pipeline, 432s (7 minutes) total

The...
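A little arithmetic summarizes the quoted timings. This sketch assumes that each of the 8 GPUs builds one pipeline, which the comment does not state explicitly. It uses the approximate per-pipeline times exactly as given.

```python
# Timings quoted above (seconds), assuming 8 pipelines (one per GPU).
NUM_PIPELINES = 8
parallel_per_pipeline = 100.0   # approximate ("~100s")
parallel_total = 108.0
sequential_per_pipeline = 50.0  # approximate ("~50s")
sequential_total = 432.0

# Wall-clock speedup of building all pipelines at once vs. one after another.
speedup = sequential_total / parallel_total

# How much slower each individual build gets when they all run concurrently.
per_pipeline_slowdown = parallel_per_pipeline / sequential_per_pipeline

# Sequential total implied by the rounded per-pipeline time alone.
sequential_estimate = NUM_PIPELINES * sequential_per_pipeline

print(f"speedup: {speedup:.1f}x")                              # speedup: 4.0x
print(f"per-pipeline slowdown: {per_pipeline_slowdown:.1f}x")  # per-pipeline slowdown: 2.0x
print(f"sequential estimate: {sequential_estimate:.0f}s")      # sequential estimate: 400s
```

Under that assumption, building in parallel is about 4x faster overall, even though each individual build takes roughly twice as long when the builds run concurrently. Because the ~50s figure is rounded, 8 × 50s = 400s slightly underestimates the measured 432s.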