jonb377

Results 35 comments of jonb377

@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.

For the TPU tests, perhaps I should try to fix the backend being initialized: the error is in pyconfig, so I don't expect the backend to be up. The...

Also, just a high-level question: @rdyro, did you notice any slowdown using gcsfuse compared to a local copy of the weights?

Thanks for reporting! A patch is in the works in https://github.com/google/maxtext/pull/895. As an immediate workaround, you can enable async checkpointing with the config `async_checkpointing=true`, which initializes the jax distributed client.

Awesome! One thing we may need to handle is autocast state with gradient checkpointing: the upstream [restores state using device modules](https://github.com/pytorch/pytorch/blob/1ec05c769b7e1c6ab5ba75f86b4ae6d43d77ac96/torch/utils/checkpoint.py#L301-L304) (e.g. `torch.cuda` or `torch.cpu`), and it [fetches the...